Refactoring solvers (#541)

* Refactoring solvers

* Simplify compile logic
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests for ensemble solvers

* formatter

* codacy

* fix issues + speedup test
Authored by Dario Coscia on 2025-04-09 14:51:42 +02:00
Committed by FilippoOlivo
Parent fa6fda0bd5, commit 1bb3c125ac
37 changed files with 1514 additions and 510 deletions
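
For orientation, a minimal usage sketch of the refactored import surface, assuming a supervised problem object built elsewhere; the class names and ``pina.solver`` re-exports follow the ``__init__.py`` changes shown below, while the MLP architecture is purely illustrative:

import torch
from pina.solver import (
    SupervisedSolver,              # same public import, new internal location
    ReducedOrderModelSolver,       # same public import, new internal location
    DeepEnsembleSupervisedSolver,  # new in this PR
    DeepEnsemblePINN,              # new in this PR
)


def make_ensemble_members(n_members=5, in_dim=1, out_dim=1):
    """Independently initialized copies of a small MLP (illustrative only)."""
    return [
        torch.nn.Sequential(
            torch.nn.Linear(in_dim, 20),
            torch.nn.Tanh(),
            torch.nn.Linear(20, out_dim),
        )
        for _ in range(n_members)
    ]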

View File

@@ -68,6 +68,8 @@ Solvers
SolverInterface <solver/solver_interface.rst>
SingleSolverInterface <solver/single_solver_interface.rst>
MultiSolverInterface <solver/multi_solver_interface.rst>
SupervisedSolverInterface <solver/supervised_solver/supervised_solver_interface>
DeepEnsembleSolverInterface <solver/ensemble_solver/ensemble_solver_interface>
PINNInterface <solver/physics_informed_solver/pinn_interface.rst>
PINN <solver/physics_informed_solver/pinn.rst>
GradientPINN <solver/physics_informed_solver/gradient_pinn.rst>
@@ -75,8 +77,10 @@ Solvers
CompetitivePINN <solver/physics_informed_solver/competitive_pinn.rst>
SelfAdaptivePINN <solver/physics_informed_solver/self_adaptive_pinn.rst>
RBAPINN <solver/physics_informed_solver/rba_pinn.rst>
SupervisedSolver <solver/supervised.rst>
ReducedOrderModelSolver <solver/reduced_order_model.rst>
DeepEnsemblePINN <solver/ensemble_solver/ensemble_pinn>
SupervisedSolver <solver/supervised_solver/supervised.rst>
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised>
ReducedOrderModelSolver <solver/supervised_solver/reduced_order_model.rst>
GAROM <solver/garom.rst>

View File

@@ -0,0 +1,8 @@
DeepEnsemblePINN
==================
.. currentmodule:: pina.solver.ensemble_solver.ensemble_pinn
.. autoclass:: DeepEnsemblePINN
:show-inheritance:
:members:

View File

@@ -0,0 +1,8 @@
DeepEnsembleSolverInterface
=============================
.. currentmodule:: pina.solver.ensemble_solver.ensemble_solver_interface
.. autoclass:: DeepEnsembleSolverInterface
:show-inheritance:
:members:

View File

@@ -0,0 +1,8 @@
DeepEnsembleSupervisedSolver
=============================
.. currentmodule:: pina.solver.ensemble_solver.ensemble_supervised
.. autoclass:: DeepEnsembleSupervisedSolver
:show-inheritance:
:members:

View File

@@ -1,6 +1,6 @@
ReducedOrderModelSolver
==========================
.. currentmodule:: pina.solver.reduced_order_model
.. currentmodule:: pina.solver.supervised_solver.reduced_order_model
.. autoclass:: ReducedOrderModelSolver
:members:

View File

@@ -1,6 +1,6 @@
SupervisedSolver
===================
.. currentmodule:: pina.solver.supervised
.. currentmodule:: pina.solver.supervised_solver.supervised
.. autoclass:: SupervisedSolver
:members:

View File

@@ -0,0 +1,8 @@
SupervisedSolverInterface
==========================
.. currentmodule:: pina.solver.supervised_solver.supervised_solver_interface
.. autoclass:: SupervisedSolverInterface
:show-inheritance:
:members:

View File

@@ -11,13 +11,33 @@ __all__ = [
"CompetitivePINN",
"SelfAdaptivePINN",
"RBAPINN",
"SupervisedSolverInterface",
"SupervisedSolver",
"ReducedOrderModelSolver",
"DeepEnsembleSolverInterface",
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
"GAROM",
]
from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
from .physics_informed_solver import *
from .supervised import SupervisedSolver
from .reduced_order_model import ReducedOrderModelSolver
from .physics_informed_solver import (
PINNInterface,
PINN,
GradientPINN,
CausalPINN,
CompetitivePINN,
SelfAdaptivePINN,
RBAPINN,
)
from .supervised_solver import (
SupervisedSolverInterface,
SupervisedSolver,
ReducedOrderModelSolver,
)
from .ensemble_solver import (
DeepEnsembleSolverInterface,
DeepEnsembleSupervisedSolver,
DeepEnsemblePINN,
)
from .garom import GAROM

View File

@@ -0,0 +1,11 @@
"""Module for the Ensemble solver classes."""
__all__ = [
"DeepEnsembleSolverInterface",
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
]
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from .ensemble_supervised import DeepEnsembleSupervisedSolver
from .ensemble_pinn import DeepEnsemblePINN

View File

@@ -0,0 +1,170 @@
"""Module for the DeepEnsemble physics solver."""
import torch
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from ..physics_informed_solver import PINNInterface
from ...problem import InverseProblem
class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface):
r"""
Deep Ensemble Physics Informed Solver class. This class implements a
Deep Ensemble for Physics Informed Neural Networks, using user-specified
``models`` to solve a specific ``problem``.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^r \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^r
(\mathbf{y}_{i} - \mathbf{\mu})^2
During training the PINN loss is minimized by each ensemble model:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
\frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
for the differential system:
.. math::
\begin{cases}
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
\mathbf{x}\in\partial\Omega
\end{cases}
:math:`\mathcal{L}` indicates a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
.. seealso::
**Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025).
*Learning and discovering multiple solutions using physics-informed
neural networks with random initialization and deep ensemble*.
DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_.
.. warning::
This solver does not work with inverse problems. Hence, the ``problem``
definition must not inherit from
:class:`~pina.problem.inverse_problem.InverseProblem`.
"""
def __init__(
self,
problem,
models,
loss=None,
optimizers=None,
schedulers=None,
weighting=None,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsemblePINN` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used for all
models. Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used for all models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
:raises NotImplementedError: If an inverse problem is passed.
"""
if isinstance(problem, InverseProblem):
raise NotImplementedError(
"DeepEnsemblePINN can not be used to solve inverse problems."
)
super().__init__(
problem=problem,
models=models,
loss=loss,
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
ensemble_dim=ensemble_dim,
)
def loss_data(self, input, target):
"""
Compute the data loss for the ensemble PINN solver by evaluating
the loss between the network's output and the true solution for each
model. This method should not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
predictions = self.forward(input)
loss = sum(
self._loss_fn(predictions[idx], target)
for idx in range(self.num_ensemble)
)
return loss / self.num_ensemble
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the ensemble PINN solver by evaluating
the loss between the network's output and the true solution for each
model. This method should not be overridden, if not intentionally.
:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation.
:return: The computed physics loss.
:rtype: LabelTensor
"""
return self._residual_loss(samples, equation)
def _residual_loss(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
provided samples and equation. This method should never be overridden
by the user unless intentionally,
since it is used internally to compute validation loss. It overrides the
:obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss`
method.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual loss.
:rtype: torch.Tensor
"""
loss = 0
predictions = self.forward(samples)
for idx in range(self.num_ensemble):
residuals = equation.residual(samples, predictions[idx])
target = torch.zeros_like(residuals, requires_grad=True)
loss = loss + self._loss_fn(residuals, target)
return loss / self.num_ensemble
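
As a construction sketch, assuming a standard (non-inverse) physics problem object is available, an ensemble PINN is built from several independently initialized copies of the same network (the architecture below is purely illustrative):

import torch
from pina.solver import DeepEnsemblePINN


def build_ensemble_pinn(problem, n_members=4):
    # Each member minimizes the same physics residual; ensemble diversity
    # comes only from the random initialization of the members.
    models = [
        torch.nn.Sequential(
            torch.nn.Linear(2, 32),
            torch.nn.Tanh(),
            torch.nn.Linear(32, 1),
        )
        for _ in range(n_members)
    ]
    # Raises NotImplementedError if `problem` is an InverseProblem.
    return DeepEnsemblePINN(problem=problem, models=models)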

View File

@@ -0,0 +1,152 @@
"""Module for the DeepEnsemble solver interface."""
import torch
from ..solver import MultiSolverInterface
from ...utils import check_consistency
class DeepEnsembleSolverInterface(MultiSolverInterface):
r"""
A class for handling ensemble models in a multi-solver training framework.
It allows for manual optimization, as well as the ability to train,
validate, and test multiple models as part of an ensemble.
The ensemble dimension can be customized to control how outputs are stacked.
By default, it is compatible with problems defined by
:class:`~pina.problem.abstract_problem.AbstractProblem`,
and users can choose the problem type the solver is meant to address.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^r \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^r
(\mathbf{y}_{i} - \mathbf{\mu})^2
.. seealso::
**Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell,
C. (2017). *Simple and scalable predictive uncertainty estimation
using deep ensembles*. Advances in neural information
processing systems, 30.
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
"""
def __init__(
self,
problem,
models,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=True,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsembleSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used for all
models. Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used for all models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
"""
super().__init__(
problem, models, optimizers, schedulers, weighting, use_lt
)
# check consistency
check_consistency(ensemble_dim, int)
self._ensemble_dim = ensemble_dim
def forward(self, x, ensemble_idx=None):
"""
Forward pass through the ensemble models. If an ``ensemble_idx`` is
provided, it returns the output of the specific model
corresponding to that index. If no index is given, it stacks the outputs
of all models along the ensemble dimension.
:param LabelTensor x: The input tensor to the models.
:param int ensemble_idx: Optional index to select a specific
model from the ensemble. If ``None``, the outputs of all models are
stacked along the ``ensemble_dim`` dimension. Default is ``None``.
:return: The output of the selected model or the stacked
outputs from all models.
:rtype: LabelTensor
"""
# if an index is passed, return the specific model output for that index
if ensemble_idx is not None:
return self.models[ensemble_idx].forward(x)
# otherwise return the stacked output
return torch.stack(
[self.forward(x, idx) for idx in range(self.num_ensemble)],
dim=self.ensemble_dim,
)
def training_step(self, batch):
"""
Training step for the solver, overridden for manual optimization.
This method performs a forward pass, calculates the loss, and applies
manual backward propagation and optimization steps for each model in
the ensemble.
:param list[tuple[str, dict]] batch: A batch of training data.
Each element is a tuple containing a condition name and a
dictionary of points.
:return: The aggregated loss after the training step.
:rtype: torch.Tensor
"""
# zero grad for optimizer
for opt in self.optimizers:
opt.instance.zero_grad()
# perform forward passes and aggregate losses
loss = super().training_step(batch)
# perform backpropagation
self.manual_backward(loss)
# optimize
for opt, sched in zip(self.optimizers, self.schedulers):
opt.instance.step()
sched.instance.step()
return loss
@property
def ensemble_dim(self):
"""
The dimension along which the ensemble outputs are stacked.
:return: The ensemble dimension.
:rtype: int
"""
return self._ensemble_dim
@property
def num_ensemble(self):
"""
The number of models in the ensemble.
:return: The number of models in the ensemble.
:rtype: int
"""
return len(self.models)
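
Because ``forward`` stacks the member outputs along ``ensemble_dim``, the ensemble mean and variance from the formulas above can be recovered directly from the stacked predictions; a small sketch, assuming a built ensemble solver and an input batch ``x``:

def ensemble_statistics(solver, x):
    # Stacked predictions: the ensemble axis is solver.ensemble_dim (0 by default).
    preds = solver.forward(x)
    mean = preds.mean(dim=solver.ensemble_dim)
    var = preds.var(dim=solver.ensemble_dim, unbiased=False)
    return mean, var

# A single member can also be queried directly, e.g.:
#     y0 = solver.forward(x, ensemble_idx=0)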

View File

@@ -0,0 +1,122 @@
"""Module for the DeepEnsemble supervised solver."""
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from ..supervised_solver import SupervisedSolverInterface
class DeepEnsembleSupervisedSolver(
SupervisedSolverInterface, DeepEnsembleSolverInterface
):
r"""
Deep Ensemble Supervised Solver class. This class implements a
Deep Ensemble Supervised Solver, using user-specified ``models`` to solve
a specific ``problem``.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^r \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^r
(\mathbf{y}_{i} - \mathbf{\mu})^2
During training the supervised loss is minimized by each ensemble model:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}_{j}(\mathbf{s}_i)),
\quad j \in (1,\dots,N_{ensemble})
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{s}_i` denote the
multiple (discretised) target functions to be approximated and the
corresponding (discretised) input functions.
.. seealso::
**Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell,
C. (2017). *Simple and scalable predictive uncertainty estimation
using deep ensembles*. Advances in neural information
processing systems, 30.
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
"""
def __init__(
self,
problem,
models,
loss=None,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=False,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsembleSupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used for all
models. Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used for all models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``False``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
"""
super().__init__(
problem=problem,
models=models,
loss=loss,
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
use_lt=use_lt,
ensemble_dim=ensemble_dim,
)
def loss_data(self, input, target):
"""
Compute the data loss for the :class:`DeepEnsembleSupervisedSolver` by
evaluating the loss between the network's output and the true solution
for each model. This method should not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
predictions = self.forward(input)
loss = sum(
self._loss_fn(predictions[idx], target)
for idx in range(self.num_ensemble)
)
return loss / self.num_ensemble
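
End to end, a training sketch assuming the usual ``pina.Trainer`` workflow and a supervised problem built elsewhere; the ``Trainer`` arguments and ``train()`` call are assumptions based on the standard PINA API, not part of this diff:

import torch
from pina import Trainer
from pina.solver import DeepEnsembleSupervisedSolver


def train_supervised_ensemble(problem, n_members=5, max_epochs=100):
    models = [
        torch.nn.Sequential(
            torch.nn.Linear(1, 20),
            torch.nn.Tanh(),
            torch.nn.Linear(20, 1),
        )
        for _ in range(n_members)
    ]
    # use_lt defaults to False for the supervised ensemble (see signature above).
    solver = DeepEnsembleSupervisedSolver(problem=problem, models=models)
    trainer = Trainer(solver, max_epochs=max_epochs)  # assumed Trainer signature
    trainer.train()
    return solver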

View File

@@ -48,18 +48,18 @@ class GAROM(MultiSolverInterface):
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
is used. Default is ``None``.
:param Optimizer optimizer_generator: The optimizer for the generator.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
discriminator. If ``None``, the :class:`torch.optim.Adam`
optimizer is used. Default is ``None``.
:param Scheduler scheduler_generator: The learning rate scheduler for
the generator.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: The learning rate scheduler
for the discriminator.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param float gamma: Ratio of expected loss for generator and
discriminator. Default is ``0.3``.
@@ -88,7 +88,7 @@ class GAROM(MultiSolverInterface):
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
self._loss = loss
self._loss_fn = loss
# set automatic optimization for GANs
self.automatic_optimization = False
@@ -157,10 +157,11 @@ class GAROM(MultiSolverInterface):
generated_snapshots = self.sample(parameters)
# generator loss
r_loss = self._loss(snapshots, generated_snapshots)
r_loss = self._loss_fn(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = (
self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
self._loss_fn(d_fake, generated_snapshots)
+ self.regularizer * r_loss
)
# backward step
@@ -170,24 +171,6 @@ class GAROM(MultiSolverInterface):
return r_loss, g_loss
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
(
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
def _train_discriminator(self, parameters, snapshots):
"""
Train the discriminator model.
@@ -207,8 +190,8 @@ class GAROM(MultiSolverInterface):
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
d_loss_real = self._loss(d_real, snapshots)
d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
d_loss_real = self._loss_fn(d_real, snapshots)
d_loss_fake = self._loss_fn(d_fake, generated_snapshots.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
@@ -288,7 +271,7 @@ class GAROM(MultiSolverInterface):
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
condition_loss[condition_name] = self._loss_fn(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
@@ -311,7 +294,7 @@ class GAROM(MultiSolverInterface):
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
condition_loss[condition_name] = self._loss_fn(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)

View File

@@ -83,15 +83,15 @@ class CausalPINN(PINN):
:class:`~pina.problem.time_dependent_problem.TimeDependentProblem`.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param float eps: The exponential decay parameter. Default is ``100``.
:raises ValueError: If the problem is not a TimeDependentProblem.
@@ -134,7 +134,7 @@ class CausalPINN(PINN):
chunk.labels = labels
# classical PINN loss
residual = self.compute_residual(samples=chunk, equation=equation)
loss_val = self.loss(
loss_val = self._loss_fn(
torch.zeros_like(residual, requires_grad=True), residual
)
time_loss.append(loss_val)

View File

@@ -69,26 +69,26 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module discriminator: The discriminator to be used.
If `None`, the discriminator is a deepcopy of the ``model``.
If ``None``, the discriminator is a deepcopy of the ``model``.
Default is ``None``.
:param torch.optim.Optimizer optimizer_model: The optimizer of the
``model``. If `None`, the :class:`torch.optim.Adam` optimizer is
``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
the ``discriminator``. If `None`, the :class:`torch.optim.Adam`
the ``discriminator``. If ``None``, the :class:`torch.optim.Adam`
optimizer is used. Default is ``None``.
:param Scheduler scheduler_model: Learning rate scheduler for the
``model``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: Learning rate scheduler for
the ``discriminator``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
if discriminator is None:
@@ -156,12 +156,27 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
residual = residual * discriminator_bets
# Compute competitive residual.
loss_val = self.loss(
loss_val = self._loss_fn(
torch.zeros_like(residual, requires_grad=True),
residual,
)
return loss_val
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor
:param target: The target to compare with the network's output.
:type target: LabelTensor
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
"""
return self._loss_fn(self.forward(input), target)
def configure_optimizers(self):
"""
Optimizer configuration.
@@ -195,24 +210,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
],
)
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
(
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
@property
def neural_net(self):
"""

View File

@@ -75,15 +75,15 @@ class GradientPINN(PINN):
gradient of the loss.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:raises ValueError: If the problem is not a SpatialProblem.
"""
@@ -116,7 +116,7 @@ class GradientPINN(PINN):
"""
# classical PINN loss
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
loss_value = self._loss_fn(
torch.zeros_like(residual, requires_grad=True), residual
)
@@ -124,7 +124,7 @@ class GradientPINN(PINN):
loss_value = loss_value.reshape(-1, 1)
loss_value.labels = ["__loss"]
loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
g_loss_phys = self.loss(
g_loss_phys = self._loss_fn(
torch.zeros_like(loss_grad, requires_grad=True), loss_grad
)
return loss_value + g_loss_phys

View File

@@ -62,15 +62,15 @@ class PINN(PINNInterface, SingleSolverInterface):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
super().__init__(
@@ -82,6 +82,21 @@ class PINN(PINNInterface, SingleSolverInterface):
loss=loss,
)
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor
:param target: The target to compare with the network's output.
:type target: LabelTensor
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
"""
return self._loss_fn(self.forward(input), target)
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
@@ -92,11 +107,8 @@ class PINN(PINNInterface, SingleSolverInterface):
:return: The computed physics loss.
:rtype: LabelTensor
"""
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
torch.zeros_like(residual, requires_grad=True), residual
)
return loss_value
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))
def configure_optimizers(self):
"""

View File

@@ -38,7 +38,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.solver.SolverInterface` class.
@@ -53,7 +53,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._loss = loss
self._loss_fn = loss
# inverse problem handling
if isinstance(self.problem, InverseProblem):
@@ -65,7 +65,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self.__metric = None
def optimization_cycle(self, batch):
def optimization_cycle(self, batch, loss_residuals=None):
"""
The optimization cycle for the PINN solver.
@@ -80,51 +80,74 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
containing the condition name and the associated scalar loss.
:rtype: dict
"""
return self._run_optimization_cycle(batch, self.loss_phys)
# which losses to use
if loss_residuals is None:
loss_residuals = self.loss_phys
# compute optimization cycle
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if "target" not in points:
input_pts = points["input"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points["input"]
output_pts = points["target"]
loss = self.loss_data(
input=input_pts.requires_grad_(), target=output_pts
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss
@torch.set_grad_enabled(True)
def validation_step(self, batch):
"""
The validation step for the PINN solver.
The validation step for the PINN solver. It returns the average residual
computed with the ``loss`` function, without aggregation.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss
return super().validation_step(
batch, loss_residuals=self._residual_loss
)
@torch.set_grad_enabled(True)
def test_step(self, batch):
"""
The test step for the PINN solver.
The test step for the PINN solver. It returns the average residual
computed with the ``loss`` function, without aggregation.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss
return super().test_step(batch, loss_residuals=self._residual_loss)
def loss_data(self, input_pts, output_pts):
@abstractmethod
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden, if not intentionally.
be overridden by the derived class.
:param LabelTensor input_pts: The input points to the neural network.
:param LabelTensor output_pts: The true solution to compare with the
:param LabelTensor input: The input to the neural network.
:param LabelTensor target: The target to compare with the
network's output.
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
:rtype: LabelTensor
"""
return self._loss(self.forward(input_pts), output_pts)
@abstractmethod
def loss_phys(self, samples, equation):
@@ -159,7 +182,11 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def _residual_loss(self, samples, equation):
"""
Compute the residual loss.
Computes the physics loss for the physics-informed solver based on the
provided samples and equation. This method should never be overridden
by the user unless intentionally,
since it is used internally to compute validation loss.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
@@ -167,43 +194,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:rtype: torch.Tensor
"""
residuals = self.compute_residual(samples, equation)
return self.loss(residuals, torch.zeros_like(residuals))
def _run_optimization_cycle(self, batch, loss_residuals):
"""
Compute, given a batch, the loss for each condition and return a
dictionary with the condition name as key and the loss as value.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param function loss_residuals: The loss function to be minimized.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if "target" not in points:
input_pts = points["input"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points["input"]
output_pts = points["target"]
loss = self.loss_data(
input_pts=input_pts.requires_grad_(), output_pts=output_pts
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss
return self._loss_fn(residuals, torch.zeros_like(residuals))
def _clamp_inverse_problem_params(self):
"""
@@ -223,7 +214,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:return: The loss function used for training.
:rtype: torch.nn.Module
"""
return self._loss
return self._loss_fn
@property
def current_condition_name(self):

View File

@@ -83,15 +83,15 @@ class RBAPINN(PINN):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param float | int eta: The learning rate for the weights of the
residuals. Default is ``0.001``.

View File

@@ -120,24 +120,24 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
:param torch.nn.Module weight_function: The Self-Adaptive mask model.
Default is ``torch.nn.Sigmoid()``.
:param Optimizer optimizer_model: The optimizer of the ``model``.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Optimizer optimizer_weights: The optimizer of the
``weight_function``.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler_model: Learning rate scheduler for the
``model``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_weights: Learning rate scheduler for the
``weight_function``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
# check consistency weight_function
@@ -223,24 +223,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
[self.scheduler_model.instance, self.scheduler_weights.instance],
)
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
(
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
def on_train_start(self):
"""
This method is called at the start of the training process to set the
@@ -304,6 +286,21 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
)
return self._vect_to_scalar(weights * loss_value)
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor
:param target: The target to compare with the network's output.
:type target: LabelTensor
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
"""
return self._loss_fn(self.forward(input), target)
def _vect_to_scalar(self, loss_value):
"""
Computation of the scalar loss.

View File

@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
import lightning
import torch
from torch._dynamo.eval_frame import OptimizedModule
from torch._dynamo import OptimizedModule
from ..problem import AbstractProblem
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from ..loss import WeightingInterface
@@ -29,7 +29,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:param AbstractProblem problem: The problem to be solved.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
super().__init__()
@@ -64,18 +64,20 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
self._pina_optimizers = None
self._pina_schedulers = None
def _check_solver_consistency(self, problem):
@abstractmethod
def forward(self, *args, **kwargs):
"""
Check the consistency of the solver with the problem formulation.
Abstract method for the forward pass implementation.
:param AbstractProblem problem: The problem to be solved.
:param args: The input tensor.
:type args: torch.Tensor | LabelTensor | Data | Graph
:param dict kwargs: Additional keyword arguments.
"""
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)
def _optimization_cycle(self, batch):
@abstractmethod
def optimization_cycle(self, batch):
"""
Aggregate the loss for each condition in the batch.
The optimization cycle for the solvers.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
@@ -84,46 +86,58 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
containing the condition name and the associated scalar loss.
:rtype: dict
"""
losses = self.optimization_cycle(batch)
for name, value in losses.items():
self.store_log(
f"{name}_loss", value.item(), self.get_batch_size(batch)
)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
return loss
def training_step(self, batch):
def training_step(self, batch, **kwargs):
"""
Solver training step.
Solver training step. It computes the optimization cycle and aggregates
the losses using the ``weighting`` attribute.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the training step.
:rtype: LabelTensor
:rtype: torch.Tensor
"""
loss = self._optimization_cycle(batch=batch)
loss = self._optimization_cycle(batch=batch, **kwargs)
self.store_log("train_loss", loss, self.get_batch_size(batch))
return loss
def validation_step(self, batch):
def validation_step(self, batch, **kwargs):
"""
Solver validation step.
Solver validation step. It computes the optimization cycle and
averages the losses. No aggregation using the ``weighting`` attribute is
performed.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
loss = self._optimization_cycle(batch=batch)
losses = self.optimization_cycle(batch=batch, **kwargs)
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss
def test_step(self, batch):
def test_step(self, batch, **kwargs):
"""
Solver test step.
Solver test step. It computes the optimization cycle and
averages the losses. No aggregation using the ``weighting`` attribute is
performed.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
loss = self._optimization_cycle(batch=batch)
losses = self.optimization_cycle(batch=batch, **kwargs)
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss
def store_log(self, name, value, batch_size):
"""
@@ -141,58 +155,118 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
**self.trainer.logging_kwargs,
)
@abstractmethod
def forward(self, *args, **kwargs):
def setup(self, stage):
"""
Abstract method for the forward pass implementation.
This method is called at the start of the train and test process to
compile the model if the :class:`~pina.trainer.Trainer`
``compile`` is ``True``.
:param args: The input tensor.
:type args: torch.Tensor | LabelTensor
:param dict kwargs: Additional keyword arguments.
"""
@abstractmethod
def optimization_cycle(self, batch):
"""
The optimization cycle for the solvers.
if stage == "fit" and self.trainer.compile:
self._setup_compile()
if stage == "test" and (
self.trainer.compile and not self._is_compiled()
):
self._setup_compile()
return super().setup(stage)
def _is_compiled(self):
"""
Check if the model is compiled.
:return: ``True`` if the model is compiled, ``False`` otherwise.
:rtype: bool
"""
for model in self._pina_models:
if not isinstance(model, OptimizedModule):
return False
return True
def _setup_compile(self):
"""
Compile all models in the solver using ``torch.compile``.
This method iterates through each model stored in the solver
list and attempts to compile them for optimized execution. It supports
models of type ``torch.nn.Module`` and ``torch.nn.ModuleDict``. For models
stored in a ``ModuleDict``, each submodule is compiled individually.
Models on Apple Silicon (MPS) use the 'eager' backend,
while others use 'inductor'.
:raises RuntimeError: If a model is neither ``torch.nn.Module``
nor ``torch.nn.ModuleDict``.
"""
for i, model in enumerate(self._pina_models):
if isinstance(model, torch.nn.ModuleDict):
for name, module in model.items():
self._pina_models[i][name] = self._compile_modules(module)
elif isinstance(model, torch.nn.Module):
self._pina_models[i] = self._compile_modules(model)
else:
raise RuntimeError(
"Compilation available only for "
"torch.nn.Module or torch.nn.ModuleDict."
)
def _check_solver_consistency(self, problem):
"""
Check the consistency of the solver with the problem formulation.
:param AbstractProblem problem: The problem to be solved.
"""
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)
def _optimization_cycle(self, batch, **kwargs):
"""
Aggregate the loss for each condition in the batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The aggregated loss for the batch, cast to a subclass of
:class:`torch.Tensor`.
:rtype: torch.Tensor
"""
losses = self.optimization_cycle(batch, **kwargs)
for name, value in losses.items():
self.store_log(
f"{name}_loss", value.item(), self.get_batch_size(batch)
)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
return loss
@property
def problem(self):
@staticmethod
def _compile_modules(model):
"""
The problem instance.
Perform the compilation of the model.
:return: The problem instance.
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
"""
return self._pina_problem
This method attempts to compile the given PyTorch model
using ``torch.compile`` to improve execution performance. The
backend is selected based on the device on which the model resides:
``eager`` is used for MPS devices (Apple Silicon), and ``inductor``
is used for all others.
@property
def use_lt(self):
"""
Using LabelTensors as input during training.
If compilation fails, the method prints the error and returns the
original, uncompiled model.
:return: The use_lt attribute.
:rtype: bool
:param torch.nn.Module model: The model to compile.
:raises Exception: If the compilation fails.
:return: The compiled model.
:rtype: torch.nn.Module
"""
return self._use_lt
@property
def weighting(self):
"""
The weighting schema.
:return: The weighting schema.
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
"""
return self._pina_weighting
model_device = next(model.parameters()).device
try:
if model_device == torch.device("mps:0"):
model = torch.compile(model, backend="eager")
else:
model = torch.compile(model, backend="inductor")
except Exception as e:
print("Compilation failed, running in normal mode.:\n", e)
return model
@staticmethod
def get_batch_size(batch):
@@ -232,62 +306,35 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
def on_train_start(self):
@property
def problem(self):
"""
This method is called at the start of the training process to compile
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
"""
super().on_train_start()
if self.trainer.compile:
self._compile_model()
The problem instance.
def on_test_start(self):
:return: The problem instance.
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
"""
This method is called at the start of the test process to compile
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
"""
super().on_train_start()
if self.trainer.compile and not self._check_already_compiled():
self._compile_model()
return self._pina_problem
def _check_already_compiled(self):
@property
def use_lt(self):
"""
Check if the model is already compiled.
Using LabelTensors as input during training.
:return: ``True`` if the model is already compiled, ``False`` otherwise.
:return: The use_lt attribute.
:rtype: bool
"""
return self._use_lt
models = self._pina_models
if len(models) == 1 and isinstance(
self._pina_models[0], torch.nn.ModuleDict
):
models = list(self._pina_models.values())
for model in models:
if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
return False
return True
@staticmethod
def _perform_compilation(model):
@property
def weighting(self):
"""
Perform the compilation of the model.
The weighting schema.
:param torch.nn.Module model: The model to compile.
:raises Exception: If the compilation fails.
:return: The compiled model.
:rtype: torch.nn.Module
:return: The weighting schema.
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
"""
model_device = next(model.parameters()).device
try:
if model_device == torch.device("mps:0"):
model = torch.compile(model, backend="eager")
else:
model = torch.compile(model, backend="inductor")
except Exception as e:
print("Compilation failed, running in normal mode.:\n", e)
return model
return self._pina_weighting
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
@@ -310,13 +357,13 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is
If ``None``, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
:param Scheduler scheduler: The scheduler to be used.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
if optimizer is None:
@@ -344,12 +391,11 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
Forward pass implementation.
:param x: Input tensor.
:type x: torch.Tensor | LabelTensor
:type x: torch.Tensor | LabelTensor | Graph | Data
:return: Solver solution.
:rtype: torch.Tensor | LabelTensor
:rtype: torch.Tensor | LabelTensor | Graph | Data
"""
x = self.model(x)
return x
return self.model(x)
def configure_optimizers(self):
"""
@@ -362,28 +408,6 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
self.scheduler.hook(self.optimizer)
return ([self.optimizer.instance], [self.scheduler.instance])
def _compile_model(self):
"""
Compile the model.
"""
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
self._compile_module_dict()
else:
self._compile_single_model()
def _compile_module_dict(self):
"""
Compile the model if it is a :class:`torch.nn.ModuleDict`.
"""
for name, model in self._pina_models[0].items():
self._pina_models[0][name] = self._perform_compilation(model)
def _compile_single_model(self):
"""
Compile the model if it is a single :class:`torch.nn.Module`.
"""
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
@property
def model(self):
"""
@@ -436,13 +460,13 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
:param models: The neural network models to be used.
:type models: list[torch.nn.Module] | tuple[torch.nn.Module]
:param list[Optimizer] optimizers: The optimizers to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used for all
If ``None``, the :class:`torch.optim.Adam` optimizer is used for all
models. Default is ``None``.
:param list[Scheduler] schedulers: The schedulers to be used.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used for all the models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length
greater than one.
@@ -519,6 +543,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
self.automatic_optimization = False
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
epoch_loop = self.trainer.fit_loop.epoch_loop
epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
def configure_optimizers(self):
"""
Optimizer configuration for the solver.
@@ -537,14 +577,6 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
[scheduler.instance for scheduler in self.schedulers],
)
def _compile_model(self):
"""
Compile the model.
"""
for i, model in enumerate(self._pina_models):
if not isinstance(model, torch.nn.ModuleDict):
self._pina_models[i] = self._perform_compilation(model)
@property
def models(self):
"""

View File

@@ -1,132 +0,0 @@
"""Module for the Supervised solver."""
import torch
from torch.nn.modules.loss import _Loss
from .solver import SingleSolverInterface
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputTargetCondition
class SupervisedSolver(SingleSolverInterface):
r"""
Supervised Solver solver class. This class implements a Supervised Solver,
using a user specified ``model`` to solve a specific ``problem``.
The Supervised Solver class aims to find a map between the input
:math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
Given a model :math:`\mathcal{M}`, the following loss function is
minimized during training:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicates
the will to approximate multiple (discretised) functions given multiple
(discretised) input functions.
"""
accepted_conditions_types = InputTargetCondition
def __init__(
self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True,
):
"""
Initialization of the :class:`SupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(
model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)
# check consistency
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
self._loss = loss
def optimization_cycle(self, batch):
"""
The optimization cycle for the solvers.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
input_pts, output_pts = (
points["input"],
points["target"],
)
condition_loss[condition_name] = self.loss_data(
input_pts=input_pts, output_pts=output_pts
)
return condition_loss
def loss_data(self, input_pts, output_pts):
"""
Compute the data loss for the Supervised solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden, if not intentionally.
:param input_pts: The input points to the neural network.
:type input_pts: LabelTensor | torch.Tensor
:param output_pts: The true solution to compare with the network's
output.
:type output_pts: LabelTensor | torch.Tensor
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
return self._loss(self.forward(input_pts), output_pts)
@property
def loss(self):
"""
The loss function to be minimized.
:return: The loss function to be minimized.
:rtype: torch.nn.Module
"""
return self._loss

View File

@@ -0,0 +1,11 @@
"""Module for the Supervised solvers."""
__all__ = [
"SupervisedSolverInterface",
"SupervisedSolver",
"ReducedOrderModelSolver",
]
from .supervised_solver_interface import SupervisedSolverInterface
from .supervised import SupervisedSolver
from .reduced_order_model import ReducedOrderModelSolver

View File

@@ -1,10 +1,11 @@
"""Module for the Reduced Order Model solver"""
import torch
from .supervised import SupervisedSolver
from .supervised_solver_interface import SupervisedSolverInterface
from ..solver import SingleSolverInterface
class ReducedOrderModelSolver(SupervisedSolver):
class ReducedOrderModelSolver(SupervisedSolverInterface, SingleSolverInterface):
r"""
Reduced Order Model solver class. This class implements the Reduced Order
Model solver, using user specified ``reduction_network`` and
@@ -51,6 +52,14 @@ class ReducedOrderModelSolver(SupervisedSolver):
DOI `10.1016/j.jcp.2018.02.037
<https://doi.org/10.1016/j.jcp.2018.02.037>`_.
Pichi, Federico, Beatriz Moya, and Jan S.
Hesthaven.
*A graph convolutional autoencoder approach to model order reduction
for parametrized PDEs.*
Journal of Computational Physics 501 (2024): 112762.
DOI `10.1016/j.jcp.2024.112762
<https://doi.org/10.1016/j.jcp.2024.112762>`_.
.. note::
The specified ``reduction_network`` must contain two methods, namely
``encode`` for input encoding, and ``decode`` for decoding the former
@@ -64,15 +73,6 @@ class ReducedOrderModelSolver(SupervisedSolver):
simultaneously. For reference on this training strategy look at the
following:
..seealso::
**Original reference**: Pichi, Federico, Beatriz Moya, and Jan S.
Hesthaven.
*A graph convolutional autoencoder approach to model order reduction
for parametrized PDEs.*
Journal of Computational Physics 501 (2024): 112762.
DOI `10.1016/j.jcp.2024.112762
<https://doi.org/10.1016/j.jcp.2024.112762>`_.
.. warning::
This solver works only for data-driven models. Hence, in the ``problem``
definition the condition must only contain ``input``
@@ -102,16 +102,16 @@ class ReducedOrderModelSolver(SupervisedSolver):
for interpolating the control parameters to latent space obtained by
the ``reduction_network`` encoding.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
@@ -153,39 +153,38 @@ class ReducedOrderModelSolver(SupervisedSolver):
of the ``interpolation_network`` on the input, and maps it to output
space by calling the decode method of the ``reduction_network``.
:param x: Input tensor.
:type x: torch.Tensor | LabelTensor
:return: Solver solution.
:rtype: torch.Tensor | LabelTensor
:param x: The input to the neural network.
:type x: LabelTensor | torch.Tensor | Graph | Data
:return: The solver solution.
:rtype: LabelTensor | torch.Tensor | Graph | Data
"""
reduction_network = self.model["reduction_network"]
interpolation_network = self.model["interpolation_network"]
return reduction_network.decode(interpolation_network(x))
def loss_data(self, input_pts, output_pts):
def loss_data(self, input, target):
"""
Compute the data loss by evaluating the loss between the network's
output and the true solution. This method should not be overridden, if
not intentionally.
:param LabelTensor input_pts: The input points to the neural network.
:param LabelTensor output_pts: The true solution to compare with the
network's output.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
:rtype: LabelTensor | torch.Tensor | Graph | Data
"""
# extract networks
reduction_network = self.model["reduction_network"]
interpolation_network = self.model["interpolation_network"]
# encoded representations loss
encode_repr_inter_net = interpolation_network(input_pts)
encode_repr_reduction_network = reduction_network.encode(output_pts)
loss_encode = self.loss(
encode_repr_inter_net = interpolation_network(input)
encode_repr_reduction_network = reduction_network.encode(target)
loss_encode = self._loss_fn(
encode_repr_inter_net, encode_repr_reduction_network
)
# reconstruction loss
loss_reconstruction = self.loss(
reduction_network.decode(encode_repr_reduction_network), output_pts
)
decode = reduction_network.decode(encode_repr_reduction_network)
loss_reconstruction = self._loss_fn(decode, target)
return loss_encode + loss_reconstruction
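
As a complement to the refactored ``loss_data``, below is a minimal, self-contained sketch of the two-network setup the solver expects; the ``Autoencoder``/``Interpolator`` names, layer sizes, and data shapes are illustrative assumptions, not part of this diff, and the snippet only mirrors the decode/encode composition and the two loss terms used above.

import torch

class Autoencoder(torch.nn.Module):
    # Hypothetical reduction network: must expose encode() and decode().
    def __init__(self, field_dim=100, latent_dim=8):
        super().__init__()
        self.encoder = torch.nn.Linear(field_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, field_dim)

    def encode(self, x):
        return torch.tanh(self.encoder(x))

    def decode(self, z):
        return self.decoder(z)

class Interpolator(torch.nn.Module):
    # Hypothetical interpolation network: maps parameters to latent coordinates.
    def __init__(self, n_params=2, latent_dim=8):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_params, 32),
            torch.nn.Tanh(),
            torch.nn.Linear(32, latent_dim),
        )

    def forward(self, x):
        return self.net(x)

reduction, interpolation = Autoencoder(), Interpolator()
params = torch.rand(16, 2)       # 16 parameter instances (input)
snapshots = torch.rand(16, 100)  # corresponding high-dimensional fields (target)

# forward(): decode the latent state predicted by the interpolation network
prediction = reduction.decode(interpolation(params))

# loss_data(): latent-matching term plus reconstruction term
mse = torch.nn.MSELoss()
loss_encode = mse(interpolation(params), reduction.encode(snapshots))
loss_reconstruction = mse(
    reduction.decode(reduction.encode(snapshots)), snapshots
)
total_loss = loss_encode + loss_reconstruction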

View File

@@ -0,0 +1,85 @@
"""Module for the Supervised solver."""
from .supervised_solver_interface import SupervisedSolverInterface
from ..solver import SingleSolverInterface
class SupervisedSolver(SupervisedSolverInterface, SingleSolverInterface):
r"""
Supervised solver class. This class implements a Supervised solver,
using a user-specified ``model`` to solve a specific ``problem``.
The Supervised Solver class aims to find a map between the input
:math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
Given a model :math:`\mathcal{M}`, the following loss function is
minimized during training:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{s}_i)),
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{s}_i` indicate
that the solver aims to approximate multiple (discretised) output functions
given multiple (discretised) input functions.
"""
def __init__(
self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True,
):
"""
Initialization of the :class:`SupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizer: The optimizer to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
super().__init__(
model=model,
problem=problem,
loss=loss,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)
def loss_data(self, input, target):
"""
Compute the data loss for the Supervised solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor | torch.Tensor | Graph | Data
"""
return self._loss_fn(self.forward(input), target)
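
For orientation, a minimal usage sketch of the refactored ``SupervisedSolver`` follows; the toy problem, layer sizes, and trainer settings are illustrative assumptions modelled on the test setup in this commit, not prescribed by the class itself.

import torch
from pina import Condition, LabelTensor
from pina.model import FeedForward
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.trainer import Trainer

class ToyProblem(AbstractProblem):
    # Illustrative data-driven problem with a single input/target condition.
    input_variables = ["x_0", "x_1"]
    output_variables = ["y"]
    conditions = {
        "data": Condition(
            input=LabelTensor(torch.randn(20, 2), ["x_0", "x_1"]),
            target=LabelTensor(torch.randn(20, 1), ["y"]),
        )
    }

solver = SupervisedSolver(problem=ToyProblem(), model=FeedForward(2, 1))
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=None,
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()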

View File

@@ -0,0 +1,90 @@
"""Module for the Supervised solver interface."""
from abc import abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...condition import InputTargetCondition
class SupervisedSolverInterface(SolverInterface):
r"""
Base class for Supervised solvers. This class implements a Supervised
solver, using a user-specified ``model`` to solve a specific ``problem``.
The ``SupervisedSolverInterface`` class can be used to define
Supervised solvers that work with one or multiple optimizers and/or models.
By default, it is compatible with problems defined by
:class:`~pina.problem.abstract_problem.AbstractProblem`,
and users can choose the problem type the solver is meant to address.
"""
accepted_conditions_types = InputTargetCondition
def __init__(self, loss=None, **kwargs):
"""
Initialization of the :class:`SupervisedSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.solver.SolverInterface` class.
"""
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(**kwargs)
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._loss_fn = loss
def optimization_cycle(self, batch):
"""
The optimization cycle for the solvers.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, cast
to a subclass of :class:`torch.Tensor`. The returned dict maps each
condition name to its associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
condition_loss[condition_name] = self.loss_data(
input=points["input"], target=points["target"]
)
return condition_loss
@abstractmethod
def loss_data(self, input, target):
"""
Compute the data loss for the Supervised solver. This method is abstract
and should be overridden by derived classes.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor | torch.Tensor | Graph | Data
"""
@property
def loss(self):
"""
The loss function to be minimized.
:return: The loss function to be minimized.
:rtype: torch.nn.Module
"""
return self._loss_fn
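
To make the data flow through ``optimization_cycle`` concrete, the snippet below builds a batch in the documented ``list[tuple[str, dict]]`` shape and reproduces the per-condition loss dictionary with a plain MSE standing in for a concrete ``loss_data``; the condition names and tensor sizes are illustrative assumptions.

import torch

mse = torch.nn.MSELoss()

def loss_data(input, target):
    # Stand-in for a concrete solver's loss_data (here the "model" is the identity).
    return mse(input, target)

# A batch is a list of (condition_name, points) tuples, where points maps
# "input" and "target" to the tensors of that condition.
batch = [
    ("data", {"input": torch.rand(8, 2), "target": torch.rand(8, 2)}),
    ("extra_data", {"input": torch.rand(4, 2), "target": torch.rand(4, 2)}),
]

condition_loss = {
    name: loss_data(input=points["input"], target=points["target"])
    for name, points in batch
}
# condition_loss maps each condition name to a scalar loss tensor.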

View File

@@ -72,18 +72,22 @@ def labelize_forward(forward, input_variables, output_variables):
:rtype: Callable
"""
def wrapper(x):
def wrapper(x, *args, **kwargs):
"""
Decorated forward function.
:param LabelTensor x: The labelized input of the forward pass of an
instance of :class:`torch.nn.Module`.
:param Iterable args: Additional positional arguments passed to
``forward`` method.
:param dict kwargs: Additional keyword arguments passed to
``forward`` method.
:return: The labelized output of the forward pass of an instance of
:class:`torch.nn.Module`.
:rtype: LabelTensor
"""
x = x.extract(input_variables)
output = forward(x)
output = forward(x, *args, **kwargs)
# keep it like this, directly using LabelTensor(...) raises errors
# when compiling the code
output = output.as_subclass(LabelTensor)
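
A standalone sketch of what the labelized wrapper achieves is given below, assuming LabelTensor inputs; the model, the variable names, and the explicit ``labels`` assignment on the output are illustrative assumptions rather than lines taken from this diff.

import torch
from pina import LabelTensor
from pina.model import FeedForward

input_variables, output_variables = ["x", "y"], ["u"]
model = FeedForward(len(input_variables), len(output_variables))

def labelized_forward(x, *args, **kwargs):
    # Keep only the declared input columns, in the declared order.
    x = x.extract(input_variables)
    out = model(x, *args, **kwargs)
    # Re-attach label metadata to the raw tensor returned by the model.
    out = out.as_subclass(LabelTensor)
    out.labels = output_variables
    return out

pts = LabelTensor(torch.rand(5, 3), ["x", "y", "t"])  # the extra "t" column is dropped
print(labelized_forward(pts).labels)  # -> ["u"]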

View File

@@ -27,12 +27,12 @@ class DummySpatialProblem(SpatialProblem):
# define problems
problem = DiffusionReactionProblem()
problem.discretise_domain(50)
problem.discretise_domain(10)
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule
# define problems
problem = Poisson()
problem.discretise_domain(50)
problem.discretise_domain(10)
inverse_problem = InversePoisson()
inverse_problem.discretise_domain(50)
inverse_problem.discretise_domain(10)
# reduce the number of data points to speed up testing
data_condition = inverse_problem.conditions["data"]
@@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10]
data_condition.target = data_condition.target[:10]
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -0,0 +1,149 @@
import pytest
import torch
from pina import LabelTensor, Condition
from pina.model import FeedForward
from pina.trainer import Trainer
from pina.solver import DeepEnsemblePINN
from pina.condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from torch._dynamo.eval_frame import OptimizedModule
# define problems
problem = Poisson()
problem.discretise_domain(10)
# add input-output condition to test supervised learning
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
# define models
models = [
FeedForward(
len(problem.input_variables), len(problem.output_variables), n_layers=1
)
for _ in range(5)
]
def test_constructor():
solver = DeepEnsemblePINN(problem=problem, models=models)
assert solver.accepted_conditions_types == (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
assert solver.num_ensemble == 5
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(batch_size, compile):
solver = DeepEnsemblePINN(models=models, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=1.0,
val_size=0.0,
test_size=0.0,
compile=compile,
)
trainer.train()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(batch_size, compile):
solver = DeepEnsemblePINN(models=models, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
compile=compile,
)
trainer.train()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(batch_size, compile):
solver = DeepEnsemblePINN(models=models, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.7,
val_size=0.2,
test_size=0.1,
compile=compile,
)
trainer.test()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
def test_train_load_restore():
dir = "tests/test_solver/tmp"
solver = DeepEnsemblePINN(models=models, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=5,
accelerator="cpu",
batch_size=None,
train_size=0.7,
val_size=0.2,
test_size=0.1,
default_root_dir=dir,
)
trainer.train()
# restore
new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
new_trainer.train(
ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/"
+ "epoch=4-step=5.ckpt"
)
# loading
new_solver = DeepEnsemblePINN.load_from_checkpoint(
f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
problem=problem,
models=models,
)
test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
torch.testing.assert_close(
new_solver.forward(test_pts), solver.forward(test_pts)
)
# rm directories
import shutil
shutil.rmtree("tests/test_solver/tmp")

View File

@@ -0,0 +1,275 @@
import torch
import pytest
from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import DeepEnsembleSupervisedSolver
from pina.model import FeedForward
from pina.trainer import Trainer
from pina.graph import KNNGraph
class LabelTensorProblem(AbstractProblem):
input_variables = ["u_0", "u_1"]
output_variables = ["u"]
conditions = {
"data": Condition(
input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]),
target=LabelTensor(torch.randn(20, 1), ["u"]),
),
}
class TensorProblem(AbstractProblem):
input_variables = ["u_0", "u_1"]
output_variables = ["u"]
conditions = {
"data": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1))
}
x = torch.rand((15, 20, 5))
pos = torch.rand((15, 20, 2))
output_ = torch.rand((15, 20, 1))
input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos)
]
class GraphProblem(AbstractProblem):
output_variables = None
conditions = {"data": Condition(input=input_, target=output_)}
x = LabelTensor(torch.rand((15, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((15, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((15, 20, 1)), ["u"])
input_ = [
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
for i in range(len(x))
]
class GraphProblemLT(AbstractProblem):
output_variables = ["u"]
input_variables = ["a", "b", "c", "d", "e"]
conditions = {"data": Condition(input=input_, target=output_)}
models = [FeedForward(2, 1) for i in range(10)]
class Models(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lift = torch.nn.Linear(5, 10)
self.activation = torch.nn.Tanh()
self.output = torch.nn.Linear(10, 1)
self.conv = GCNConv(10, 10)
def forward(self, batch):
x = batch.x
edge_index = batch.edge_index
for _ in range(1):
y = self.lift(x)
y = self.activation(y)
y = self.conv(y, edge_index)
y = self.activation(y)
y = self.output(y)
return y
graph_models = [Models() for i in range(10)]
def test_constructor():
solver = DeepEnsembleSupervisedSolver(
problem=TensorProblem(), models=models
)
DeepEnsembleSupervisedSolver(problem=LabelTensorProblem(), models=models)
assert DeepEnsembleSupervisedSolver.accepted_conditions_types == (
InputTargetCondition
)
assert solver.num_ensemble == 10
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(use_lt, batch_size, compile):
problem = LabelTensorProblem() if use_lt else TensorProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=1.0,
test_size=0.0,
val_size=0.0,
compile=compile,
)
trainer.train()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=graph_models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=1.0,
test_size=0.0,
val_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(use_lt, compile):
problem = LabelTensorProblem() if use_lt else TensorProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
train_size=0.9,
val_size=0.1,
test_size=0.0,
compile=compile,
)
trainer.train()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_validation_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=graph_models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(use_lt, compile):
problem = LabelTensorProblem() if use_lt else TensorProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
train_size=0.8,
val_size=0.1,
test_size=0.1,
compile=compile,
)
trainer.test()
if trainer.compile:
assert all(
[isinstance(model, OptimizedModule) for model in solver.models]
)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_test_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = DeepEnsembleSupervisedSolver(
problem=problem, models=graph_models, use_lt=use_lt
)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.8,
val_size=0.1,
test_size=0.1,
)
trainer.test()
def test_train_load_restore():
dir = "tests/test_solver/tmp/"
problem = LabelTensorProblem()
solver = DeepEnsembleSupervisedSolver(problem=problem, models=models)
trainer = Trainer(
solver=solver,
max_epochs=5,
accelerator="cpu",
batch_size=None,
train_size=0.9,
test_size=0.1,
val_size=0.0,
default_root_dir=dir,
)
trainer.train()
# restore
new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
new_trainer.train(
ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/"
+ "epoch=4-step=5.ckpt"
)
# loading
new_solver = DeepEnsembleSupervisedSolver.load_from_checkpoint(
f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
problem=problem,
models=models,
)
test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
torch.testing.assert_close(
new_solver.forward(test_pts), solver.forward(test_pts)
)
# rm directories
import shutil
shutil.rmtree("tests/test_solver/tmp")

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import pytest
from pina import Condition, LabelTensor
from pina import Condition
from pina.solver import GAROM
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
@@ -15,7 +15,7 @@ class TensorProblem(AbstractProblem):
input_variables = ["u_0", "u_1"]
output_variables = ["u"]
conditions = {
"data": Condition(target=torch.randn(50, 2), input=torch.randn(50, 1))
"data": Condition(target=torch.randn(10, 2), input=torch.randn(10, 1))
}

View File

@@ -30,9 +30,9 @@ class DummyTimeProblem(TimeDependentProblem):
# define problems
problem = Poisson()
problem.discretise_domain(50)
problem.discretise_domain(10)
inverse_problem = InversePoisson()
inverse_problem.discretise_domain(50)
inverse_problem.discretise_domain(10)
# reduce the number of data points to speed up testing
data_condition = inverse_problem.conditions["data"]
@@ -40,9 +40,9 @@ data_condition.input = data_condition.input[:10]
data_condition.target = data_condition.target[:10]
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule
# define problems
problem = Poisson()
problem.discretise_domain(50)
problem.discretise_domain(10)
inverse_problem = InversePoisson()
inverse_problem.discretise_domain(50)
inverse_problem.discretise_domain(10)
# reduce the number of data points to speed up testing
data_condition = inverse_problem.conditions["data"]
@@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10]
data_condition.target = data_condition.target[:10]
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -18,9 +18,9 @@ from torch._dynamo.eval_frame import OptimizedModule
# define problems
problem = Poisson()
problem.discretise_domain(50)
problem.discretise_domain(10)
inverse_problem = InversePoisson()
inverse_problem.discretise_domain(50)
inverse_problem.discretise_domain(10)
# reduce the number of data points to speed up testing
data_condition = inverse_problem.conditions["data"]
@@ -28,9 +28,9 @@ data_condition.input = data_condition.input[:10]
data_condition.target = data_condition.target[:10]
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule
# define problems
problem = Poisson()
problem.discretise_domain(50)
problem.discretise_domain(10)
inverse_problem = InversePoisson()
inverse_problem.discretise_domain(50)
inverse_problem.discretise_domain(10)
# reduce the number of data points to speed up testing
data_condition = inverse_problem.conditions["data"]
@@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10]
data_condition.target = data_condition.target[:10]
# add input-output condition to test supervised learning
input_pts = torch.rand(50, len(problem.input_variables))
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

View File

@@ -30,9 +30,9 @@ class TensorProblem(AbstractProblem):
}
x = torch.rand((100, 20, 5))
pos = torch.rand((100, 20, 2))
output_ = torch.rand((100, 20, 1))
x = torch.rand((15, 20, 5))
pos = torch.rand((15, 20, 2))
output_ = torch.rand((15, 20, 1))
input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos)
@@ -44,9 +44,9 @@ class GraphProblem(AbstractProblem):
conditions = {"data": Condition(input=input_, target=output_)}
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
x = LabelTensor(torch.rand((15, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((15, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((15, 20, 1)), ["u"])
input_ = [
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
for i in range(len(x))