fix rendering part 2

This commit is contained in:
giovanni
2025-03-14 00:10:18 +01:00
committed by Nicola Demo
parent e0ad4dc8a0
commit d2e3f458ab
17 changed files with 217 additions and 147 deletions

View File

@@ -18,12 +18,12 @@ from ...condition import (
class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
the :class:`~pina.solver.SolverInterface` class.
the :class:`~pina.solver.solver.SolverInterface` class.
The `PINNInterface` class can be used to define PINNs that work with one or
multiple optimizers and/or models. By default, it is compatible with
problems defined by :class:`~pina.problem.AbstractProblem`, and users can
choose the problem type the solver is meant to address.
problems defined by :class:`~pina.problem.abstract_problem.AbstractProblem`,
and users can choose the problem type the solver is meant to address.
"""
accepted_conditions_types = (
@@ -38,10 +38,10 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the Mean Squared Error (MSE) loss is used.
Default is ``None``.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.SolverInterface` class.
:class:`~pina.solver.solver.SolverInterface` class.
"""
if loss is None:
@@ -73,9 +73,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
loss as argument, thus distinguishing the training step from the
validation and test steps.
:param dict batch: The batch of data to use in the optimization cycle.
:return: The loss of the optimization cycle.
:rtype: torch.Tensor
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
return self._run_optimization_cycle(batch, self.loss_phys)
@@ -84,7 +87,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
The validation step for the PINN solver.
:param dict batch: The batch of data to use in the validation step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
@@ -98,7 +102,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
The test step for the PINN solver.
:param dict batch: The batch of data to use in the test step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
@@ -169,10 +174,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
Compute, given a batch, the loss for each condition and return a
dictionary with the condition name as key and the loss as value.
:param dict batch: The batch of data to use in the optimization cycle.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param function loss_residuals: The loss function to be minimized.
:return: The loss for each condition.
:rtype dict
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch: