PINN variants addition and Solvers Update (#263)
* gpinn/basepinn new classes, pinn restructure
* codacy fix gpinn/basepinn/pinn
* inverse problem fix
* Causal PINN (#267)
* fix GPU training in inverse problem (#283)
* Create a `compute_residual` attribute for `PINNInterface`
* Modify dataloading in solvers (#286)
* Modify PINNInterface by removing _loss_phys, _loss_data
* Adding in PINNInterface a variable to track the current condition during training
* Modify GPINN, PINN, CausalPINN to match changes in PINNInterface
* Competitive PINN addition (#288)
* fixing after rebase / fix loss
* fixing final issues
* Modify min-max formulation to max-min for paper consistency
* Adding SAPINN solver (#291)
* rom solver
* fix import

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Anna Ivagnes <75523024+annaivagnes@users.noreply.github.com>
Co-authored-by: valc89 <103250118+valc89@users.noreply.github.com>
Co-authored-by: Monthly Tag bot <mtbot@noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
@@ -1,7 +1,6 @@
 """ Module for SupervisedSolver """
 
 import torch
-import sys
 
 try:
     from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
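The guarded import above is the usual torch version shim; the hunk cuts off before the fallback branch. For reference, the full pattern presumably looks like the sketch below, where the `_LRScheduler` fallback is the standard pre-2.0 name and is an assumption, since it is not shown in this diff:

# Version-compatibility shim: torch renamed the public scheduler base
# class in 2.0; older releases expose it as the private _LRScheduler.
try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler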
@@ -20,9 +19,32 @@ from torch.nn.modules.loss import _Loss
 
 
 class SupervisedSolver(SolverInterface):
-    """
+    r"""
     SupervisedSolver solver class. This class implements a SupervisedSolver,
     using a user specified ``model`` to solve a specific ``problem``.
+
+    The Supervised Solver class aims to find a map between the input
+    :math:`\mathbf{v}:\Omega\rightarrow\mathbb{R}^m` and the output
+    :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input can be
+    discretised in space (as in :obj:`~pina.solvers.rom.ROMe2eSolver`),
+    or not (e.g. when training Neural Operators).
+
+    Given a model :math:`\mathcal{M}`, the following loss function is
+    minimized during training:
+
+    .. math::
+        \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
+        \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
+
+    where :math:`\mathcal{L}` is a given loss function, by default the
+    Mean Squared Error:
+
+    .. math::
+        \mathcal{L}(v) = \| v \|^2_2.
+
+    The subscripts on :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicate
+    that we seek to approximate multiple (discretised) functions given
+    multiple (discretised) input functions.
     """
 
     def __init__(
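As a reading aid for the docstring above, the minimized objective amounts to the following self-contained PyTorch sketch; `v`, `u`, and the linear stand-in for :math:`\mathcal{M}` are illustrative placeholders, not PINA code:

import torch

# Hypothetical data: N input/output samples of dimension m,
# standing in for the discretised functions v_i and u_i.
N, m = 32, 8
v = torch.rand(N, m)       # inputs  v_i
u = torch.rand(N, m)       # targets u_i
M = torch.nn.Linear(m, m)  # stand-in for the model M

# L_problem = (1/N) * sum_i ||u_i - M(v_i)||_2^2, per the docstring;
# this equals torch's MSE up to the extra 1/m normalisation of the mean.
residual = u - M(v)
loss = residual.pow(2).sum(dim=1).mean()
loss.backward()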
@@ -96,18 +118,12 @@ class SupervisedSolver(SolverInterface):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
 
-        dataloader = self.trainer.train_dataloader
-
         condition_idx = batch["condition"]
 
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
-            if sys.version_info >= (3, 8):
-                condition_name = dataloader.condition_names[condition_id]
-            else:
-                condition_name = dataloader.loaders.condition_names[
-                    condition_id
-                ]
+
+            condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
             pts = batch["pts"]
             out = batch["output"]
@@ -118,14 +134,14 @@ class SupervisedSolver(SolverInterface):
             # for data driven mode
             if not hasattr(condition, "output_points"):
                 raise NotImplementedError(
-                    "Supervised solver works only in data-driven mode."
+                    f"{type(self).__name__} works only in data-driven mode."
                 )
 
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
 
             loss = (
-                self.loss(self.forward(input_pts), output_pts)
+                self.loss_data(input_pts=input_pts, output_pts=output_pts)
                 * condition.data_weight
             )
             loss = loss.as_subclass(torch.Tensor)
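The loop body above picks out each condition's points with a boolean mask and scales the loss by `condition.data_weight`. A standalone sketch of that masking pattern follows; all names are placeholders, and the losses are accumulated here only for illustration:

import torch

# Placeholder batch: six points, each tagged with its condition id,
# mirroring batch["pts"], batch["output"] and batch["condition"].
pts = torch.rand(6, 2)
out = torch.rand(6, 1)
condition_idx = torch.tensor([0, 0, 1, 1, 1, 2])
weights = {0: 1.0, 1: 0.5, 2: 2.0}  # stand-ins for condition.data_weight

model = torch.nn.Linear(2, 1)
total = torch.tensor(0.0)
for cid in range(int(condition_idx.min()), int(condition_idx.max()) + 1):
    mask = condition_idx == cid  # select this condition's points
    pred = model(pts[mask])      # forward only this condition's inputs
    total = total + weights[cid] * torch.nn.functional.mse_loss(pred, out[mask])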
@@ -133,6 +149,20 @@ class SupervisedSolver(SolverInterface):
         self.log("mean_loss", float(loss), prog_bar=True, logger=True)
         return loss
 
+    def loss_data(self, input_pts, output_pts):
+        """
+        The data loss for the Supervised solver. It computes the loss
+        between the network output and the true solution. This function
+        should not be overridden unless intentionally.
+
+        :param LabelTensor input_pts: The input to the neural network.
+        :param LabelTensor output_pts: The true solution to compare with
+            the network output.
+        :return: The residual loss averaged over the input coordinates.
+        :rtype: torch.Tensor
+        """
+        return self.loss(self.forward(input_pts), output_pts)
+
     @property
     def scheduler(self):
         """