fix pinn doc

giovanni authored 2025-03-13 16:50:05 +01:00, committed by Nicola Demo
parent 9a26c94e07
commit 28ef4c823b
8 changed files with 377 additions and 337 deletions


@@ -1,4 +1,4 @@
"""Module for Physics Informed Neural Network Interface."""
"""Module for the Physics Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import torch
@@ -17,14 +17,13 @@ from ...condition import (
class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
- Base PINN solver class. This class implements the Solver Interface
- for Physics Informed Neural Network solver.
+ Base class for Physics-Informed Neural Network (PINN) solvers, implementing
+ the :class:`~pina.solver.SolverInterface` class.
- This class can be used to define PINNs with multiple ``optimizers``,
- and/or ``models``.
- By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
- so the user can choose what type of problem the implemented solver,
- inheriting from this class, is designed to solve.
+ The `PINNInterface` class can be used to define PINNs that work with one or
+ multiple optimizers and/or models. By default, it is compatible with
+ problems defined by :class:`~pina.problem.AbstractProblem`, and users can
+ choose the problem type the solver is meant to address.
"""
accepted_conditions_types = (
@@ -35,9 +34,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def __init__(self, problem, loss=None, **kwargs):
"""
- :param AbstractProblem problem: A problem definition instance.
- :param torch.nn.Module loss: The loss function to be minimized,
- default `None`.
+ Initialization of the :class:`PINNInterface` class.
+ :param AbstractProblem problem: The problem to be solved.
+ :param torch.nn.Module loss: The loss function to be minimized.
+ If ``None``, the Mean Squared Error (MSE) loss is used.
+ Default is ``None``.
+ :param kwargs: Additional keyword arguments to be passed to the
+ :class:`~pina.solver.SolverInterface` class.
"""
if loss is None:
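For context, a minimal usage sketch of the ``loss`` argument documented above, as it would appear with a concrete solver built on this interface. The ``PINN`` import and the ``problem``/``model`` placeholders are assumptions for illustration, not part of this diff.

import torch
from pina.solver import PINN  # concrete solver implementing PINNInterface (import path assumed)

# `problem` is a user-defined AbstractProblem instance; `model` is any torch.nn.Module.
solver_mse = PINN(problem=problem, model=model)  # loss=None -> defaults to MSE
solver_l1 = PINN(problem=problem, model=model, loss=torch.nn.L1Loss())  # any torch.nn.Module loss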
@@ -62,10 +66,28 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self.__metric = None
def optimization_cycle(self, batch):
"""
The optimization cycle for the PINN solver.
This method allows to call `_run_optimization_cycle` with the physics
loss as argument, thus distinguishing the training step from the
validation and test steps.
:param dict batch: The batch of data to use in the optimization cycle.
:return: The loss of the optimization cycle.
:rtype: torch.Tensor
"""
return self._run_optimization_cycle(batch, self.loss_phys)
@torch.set_grad_enabled(True)
def validation_step(self, batch):
"""
The validation step for the PINN solver.
:param dict batch: The batch of data to use in the validation step.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -73,6 +95,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@torch.set_grad_enabled(True)
def test_step(self, batch):
"""
The test step for the PINN solver.
:param dict batch: The batch of data to use in the test step.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch))
@@ -80,14 +109,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def loss_data(self, input_pts, output_pts):
"""
- The data loss for the PINN solver. It computes the loss between
- the network output against the true solution. This function
- should not be override if not intentionally.
+ Compute the data loss for the PINN solver by evaluating the loss
+ between the network's output and the true solution. This method
+ should only be overridden intentionally.
- :param LabelTensor input_pts: The input to the neural networks.
- :param LabelTensor output_pts: The true solution to compare the
- network solution.
- :return: The residual loss averaged on the input coordinates
+ :param LabelTensor input_pts: The input points to the neural network.
+ :param LabelTensor output_pts: The true solution to compare with the
+ network's output.
+ :return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
return self._loss(self.forward(input_pts), output_pts)
@@ -95,28 +124,23 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@abstractmethod
def loss_phys(self, samples, equation):
"""
- Computes the physics loss for the physics informed solver based on given
- samples and equation. This method must be override by all inherited
- classes and it is the core to define a new physics informed solver.
+ Computes the physics loss for the physics-informed solver based on the
+ provided samples and equation. This method must be overridden in
+ subclasses. It distinguishes different types of PINN solvers.
:param LabelTensor samples: The samples to evaluate the physics loss.
- :param EquationInterface equation: The governing equation
- representing the physics.
- :return: The physics loss calculated based on given
- samples and equation.
+ :param EquationInterface equation: The governing equation.
+ :return: The computed physics loss.
:rtype: LabelTensor
"""
def compute_residual(self, samples, equation):
"""
- Compute the residual for Physics Informed learning. This function
- returns the :obj:`~pina.equation.equation.Equation` specified in the
- :obj:`~pina.condition.Condition` evaluated at the ``samples`` points.
+ Compute the residuals of the equation.
- :param LabelTensor samples: The samples to evaluate the physics loss.
- :param EquationInterface equation: The governing equation
- representing the physics.
- :return: The residual of the neural network solution.
+ :param LabelTensor samples: The samples to evaluate the loss.
+ :param EquationInterface equation: The governing equation.
+ :return: The residual of the solution of the model.
:rtype: LabelTensor
"""
try:
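To make the contract of ``loss_phys`` and ``compute_residual`` concrete, here is a minimal sketch of how a subclass might override ``loss_phys`` in the standard residual-penalizing way, mirroring the ``_residual_loss`` helper shown below; the class name and import path are illustrative assumptions, not code from this commit.

import torch
from pina.solver.physics_informed_solver import PINNInterface  # import path assumed

class ResidualPINN(PINNInterface):
    # Illustrative fragment: only loss_phys is shown; the remaining
    # SolverInterface requirements (models, optimizers, forward, ...) are omitted.

    def loss_phys(self, samples, equation):
        # Evaluate the governing equation on the network output at the sample points.
        residual = self.compute_residual(samples, equation)
        # Penalize any deviation of the residual from zero.
        return self.loss(residual, torch.zeros_like(residual))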
@@ -129,10 +153,27 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
return residual
def _residual_loss(self, samples, equation):
"""
Compute the residual loss.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual loss.
:rtype: torch.Tensor
"""
residuals = self.compute_residual(samples, equation)
return self.loss(residuals, torch.zeros_like(residuals))
def _run_optimization_cycle(self, batch, loss_residuals):
"""
Compute, given a batch, the loss for each condition and return a
dictionary with the condition name as key and the loss as value.
:param dict batch: The batch of data to use in the optimization cycle.
:param function loss_residuals: The loss function to be minimized.
:return: The loss for each condition.
:rtype dict
"""
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
@@ -158,8 +199,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def _clamp_inverse_problem_params(self):
"""
- Clamps the parameters of the inverse problem
- solver to the specified ranges.
+ Clamps the parameters of the inverse problem solver to specified ranges.
"""
for v in self._params:
self._params[v].data.clamp_(
@@ -170,7 +210,10 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@property
def loss(self):
"""
- Loss used for training.
+ The loss used for training.
+ :return: The loss function used for training.
+ :rtype: torch.nn.Module
"""
return self._loss
@@ -178,5 +221,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def current_condition_name(self):
"""
The current condition name.
+ :return: The current condition name.
+ :rtype: str
"""
return self.__metric
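Finally, a rough end-to-end sketch of how a solver built on this interface is trained; the ``Trainer`` usage and the ``problem``/``model`` placeholders are assumptions for illustration, not part of this commit.

from pina import Trainer
from pina.solver import PINN  # concrete solver implementing PINNInterface (assumed)

# `problem` is a user-defined AbstractProblem with physics and/or data conditions;
# `model` is any torch.nn.Module approximating the solution.
solver = PINN(problem=problem, model=model)

# Each optimization cycle produces one loss per condition (physics conditions
# via loss_phys, data conditions via loss_data); the weighting scheme then
# aggregates them into the single scalar that is logged and backpropagated.
trainer = Trainer(solver=solver, max_epochs=1000)
trainer.train()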