fix pinn doc
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for Physics Informed Neural Network Interface."""
|
||||
"""Module for the Physics Informed Neural Network Interface."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
@@ -17,14 +17,13 @@ from ...condition import (
|
||||
|
||||
class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Base PINN solver class. This class implements the Solver Interface
|
||||
for Physics Informed Neural Network solver.
|
||||
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
|
||||
the :class:`~pina.solver.SolverInterface` class.
|
||||
|
||||
This class can be used to define PINNs with multiple ``optimizers``,
|
||||
and/or ``models``.
|
||||
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
|
||||
so the user can choose what type of problem the implemented solver,
|
||||
inheriting from this class, is designed to solve.
|
||||
The `PINNInterface` class can be used to define PINNs that work with one or
|
||||
multiple optimizers and/or models. By default, it is compatible with
|
||||
problems defined by :class:`~pina.problem.AbstractProblem`, and users can
|
||||
choose the problem type the solver is meant to address.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
@@ -35,9 +34,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def __init__(self, problem, loss=None, **kwargs):
|
||||
"""
|
||||
:param AbstractProblem problem: A problem definition instance.
|
||||
:param torch.nn.Module loss: The loss function to be minimized,
|
||||
default `None`.
|
||||
Initialization of the :class:`PINNInterface` class.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If ``None``, the Mean Squared Error (MSE) loss is used.
|
||||
Default is ``None``.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.solver.SolverInterface` class.
|
||||
"""
|
||||
|
||||
if loss is None:
|
||||
@@ -62,10 +66,28 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self.__metric = None
|
||||
|
||||
def optimization_cycle(self, batch):
|
||||
"""
|
||||
The optimization cycle for the PINN solver.
|
||||
|
||||
This method allows to call `_run_optimization_cycle` with the physics
|
||||
loss as argument, thus distinguishing the training step from the
|
||||
validation and test steps.
|
||||
|
||||
:param dict batch: The batch of data to use in the optimization cycle.
|
||||
:return: The loss of the optimization cycle.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._run_optimization_cycle(batch, self.loss_phys)
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def validation_step(self, batch):
|
||||
"""
|
||||
The validation step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the validation step.
|
||||
:return: The loss of the validation step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
@@ -73,6 +95,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def test_step(self, batch):
|
||||
"""
|
||||
The test step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the test step.
|
||||
:return: The loss of the test step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
@@ -80,14 +109,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the PINN solver. It computes the loss between
|
||||
the network output against the true solution. This function
|
||||
should not be override if not intentionally.
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method
|
||||
should only be overridden intentionally.
|
||||
|
||||
:param LabelTensor input_pts: The input to the neural networks.
|
||||
:param LabelTensor output_pts: The true solution to compare the
|
||||
network solution.
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:param LabelTensor input_pts: The input points to the neural network.
|
||||
:param LabelTensor output_pts: The true solution to compare with the
|
||||
network's output.
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._loss(self.forward(input_pts), output_pts)
|
||||
@@ -95,28 +124,23 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the physics informed solver based on given
|
||||
samples and equation. This method must be override by all inherited
|
||||
classes and it is the core to define a new physics informed solver.
|
||||
Computes the physics loss for the physics-informed solver based on the
|
||||
provided samples and equation. This method must be overridden in
|
||||
subclasses. It distinguishes different types of PINN solvers.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
:return: The computed physics loss.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
def compute_residual(self, samples, equation):
|
||||
"""
|
||||
Compute the residual for Physics Informed learning. This function
|
||||
returns the :obj:`~pina.equation.equation.Equation` specified in the
|
||||
:obj:`~pina.condition.Condition` evaluated at the ``samples`` points.
|
||||
Compute the residuals of the equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The residual of the neural network solution.
|
||||
:param LabelTensor samples: The samples to evaluate the loss.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
:return: The residual of the solution of the model.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
try:
|
||||
@@ -129,10 +153,27 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
return residual
|
||||
|
||||
def _residual_loss(self, samples, equation):
|
||||
"""
|
||||
Compute the residual loss.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the loss.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
:return: The residual loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self.loss(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _run_optimization_cycle(self, batch, loss_residuals):
|
||||
"""
|
||||
Compute, given a batch, the loss for each condition and return a
|
||||
dictionary with the condition name as key and the loss as value.
|
||||
|
||||
:param dict batch: The batch of data to use in the optimization cycle.
|
||||
:param function loss_residuals: The loss function to be minimized.
|
||||
:return: The loss for each condition.
|
||||
:rtype dict
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
@@ -158,8 +199,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem
|
||||
solver to the specified ranges.
|
||||
Clamps the parameters of the inverse problem solver to specified ranges.
|
||||
"""
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
@@ -170,7 +210,10 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss used for training.
|
||||
The loss used for training.
|
||||
|
||||
:return: The loss function used for training.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._loss
|
||||
|
||||
@@ -178,5 +221,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
def current_condition_name(self):
|
||||
"""
|
||||
The current condition name.
|
||||
|
||||
:return: The current condition name.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.__metric
|
||||
|
||||
Reference in New Issue
Block a user