fix rendering part 2
This commit is contained in:
@@ -18,12 +18,12 @@ from ...condition import (
|
||||
class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
|
||||
the :class:`~pina.solver.SolverInterface` class.
|
||||
the :class:`~pina.solver.solver.SolverInterface` class.
|
||||
|
||||
The `PINNInterface` class can be used to define PINNs that work with one or
|
||||
multiple optimizers and/or models. By default, it is compatible with
|
||||
problems defined by :class:`~pina.problem.AbstractProblem`, and users can
|
||||
choose the problem type the solver is meant to address.
|
||||
problems defined by :class:`~pina.problem.abstract_problem.AbstractProblem`,
|
||||
and users can choose the problem type the solver is meant to address.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
@@ -38,10 +38,10 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If ``None``, the Mean Squared Error (MSE) loss is used.
|
||||
Default is ``None``.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.solver.SolverInterface` class.
|
||||
:class:`~pina.solver.solver.SolverInterface` class.
|
||||
"""
|
||||
|
||||
if loss is None:
|
||||
@@ -73,9 +73,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
loss as argument, thus distinguishing the training step from the
|
||||
validation and test steps.
|
||||
|
||||
:param dict batch: The batch of data to use in the optimization cycle.
|
||||
:return: The loss of the optimization cycle.
|
||||
:rtype: torch.Tensor
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._run_optimization_cycle(batch, self.loss_phys)
|
||||
|
||||
@@ -84,7 +87,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
The validation step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the validation step.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the validation step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
@@ -98,7 +102,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
The test step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the test step.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the test step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
@@ -169,10 +174,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
Compute, given a batch, the loss for each condition and return a
|
||||
dictionary with the condition name as key and the loss as value.
|
||||
|
||||
:param dict batch: The batch of data to use in the optimization cycle.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param function loss_residuals: The loss function to be minimized.
|
||||
:return: The loss for each condition.
|
||||
:rtype dict
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
|
||||
Reference in New Issue
Block a user