Refactoring solvers (#541)

* Refactoring solvers

* Simplify compile logic
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests for ensemble solvers

* formatter

* codacy

* fix issues + speedup test
Dario Coscia
2025-04-09 14:51:42 +02:00
parent 485c8dd789
commit 6dd7bd2825
37 changed files with 1514 additions and 510 deletions


@@ -0,0 +1,11 @@
"""Module for the Ensemble solver classes."""
__all__ = [
"DeepEnsembleSolverInterface",
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
]
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from .ensemble_supervised import DeepEnsembleSupervisedSolver
from .ensemble_pinn import DeepEnsemblePINN
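Assuming the package layout implied by the relative imports in this commit (a ``pina.solver.ensemble_solver`` subpackage; the exact path is an assumption, not confirmed by the diff), the new public names would be imported as:

# Hypothetical import path, inferred from the relative imports in this commit
from pina.solver.ensemble_solver import (
    DeepEnsemblePINN,
    DeepEnsembleSolverInterface,
    DeepEnsembleSupervisedSolver,
)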


@@ -0,0 +1,170 @@
"""Module for the DeepEnsemble physics solver."""
import torch
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from ..physics_informed_solver import PINNInterface
from ...problem import InverseProblem
class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface):
r"""
Deep Ensemble Physics-Informed Solver class. This class implements a
Deep Ensemble for Physics-Informed Neural Networks, using
user-specified ``models`` to solve a specific ``problem``.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i}`, :math:`i \in (1,\dots,r)`, in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^{r} \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^{r}
(\mathbf{y}_{i} - \mathbf{\mu})^2
During training, the PINN loss is minimized by each ensemble model:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
\frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
for the differential system:
.. math::
\begin{cases}
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
\mathbf{x}\in\partial\Omega
\end{cases}
:math:`\mathcal{L}` indicates a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
.. seealso::
**Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025).
*Learning and discovering multiple solutions using physics-informed
neural networks with random initialization and deep ensemble*.
DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_.
.. warning::
This solver does not work with inverse problems. Hence, the ``problem``
definition must not inherit from
:class:`~pina.problem.inverse_problem.InverseProblem`.
"""
def __init__(
self,
problem,
models,
loss=None,
optimizers=None,
schedulers=None,
weighting=None,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsemblePINN` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
:raises NotImplementedError: If an inverse problem is passed.
"""
if isinstance(problem, InverseProblem):
raise NotImplementedError(
"DeepEnsemblePINN can not be used to solve inverse problems."
)
super().__init__(
problem=problem,
models=models,
loss=loss,
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
ensemble_dim=ensemble_dim,
)
def loss_data(self, input, target):
"""
Compute the data loss for the ensemble PINN solver by evaluating
the loss between the network's output and the true solution for each
model. This method should not be overridden unless done intentionally.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the ensemble models.
:rtype: torch.Tensor
"""
predictions = self.forward(input)
loss = sum(
self._loss_fn(predictions[idx], target)
for idx in range(self.num_ensemble)
)
return loss / self.num_ensemble
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the ensemble PINN solver by evaluating
the loss between the network's output and the true solution for each
model. This method should not be overridden unless done intentionally.
:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation.
:return: The computed physics loss.
:rtype: torch.Tensor
"""
return self._residual_loss(samples, equation)
def _residual_loss(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
provided samples and equation. This method should never be overridden
by the user unless done intentionally, since it is used internally to
compute the validation loss. It overrides the
:obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss`
method.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual loss.
:rtype: torch.Tensor
"""
loss = 0
predictions = self.forward(samples)
for idx in range(self.num_ensemble):
residuals = equation.residual(samples, predictions[idx])
target = torch.zeros_like(residuals, requires_grad=True)
loss = loss + self._loss_fn(residuals, target)
return loss / self.num_ensemble
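As a standalone illustration of the averaging performed by ``_residual_loss``, here is a minimal plain-PyTorch sketch, not the PINA API: the toy equation u'(x) - u(x) = 0, the member architectures, and the collocation points are all assumptions.

import torch

torch.manual_seed(0)
# three toy ensemble members (illustrative architecture)
models = [
    torch.nn.Sequential(
        torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
    )
    for _ in range(3)
]
x = torch.rand(32, 1, requires_grad=True)  # collocation points

loss = 0.0
for model in models:
    u = model(x)
    # residual of the toy equation u'(x) - u(x) = 0, computed via autograd
    du = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True
    )[0]
    residual = du - u
    loss = loss + torch.nn.functional.mse_loss(
        residual, torch.zeros_like(residual)
    )
loss = loss / len(models)  # average over the ensemble, as in `_residual_loss`
loss.backward()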


@@ -0,0 +1,152 @@
"""Module for the DeepEnsemble solver interface."""
import torch
from ..solver import MultiSolverInterface
from ...utils import check_consistency
class DeepEnsembleSolverInterface(MultiSolverInterface):
r"""
A class for handling ensemble models in a multi-solver training framework.
It allows for manual optimization, as well as the ability to train,
validate, and test multiple models as part of an ensemble.
The ensemble dimension can be customized to control how outputs are stacked.
By default, it is compatible with problems defined by
:class:`~pina.problem.abstract_problem.AbstractProblem`,
and users can choose the problem type the solver is meant to address.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i}`, :math:`i \in (1,\dots,r)`, in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^{r} \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^{r}
(\mathbf{y}_{i} - \mathbf{\mu})^2
.. seealso::
**Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell,
C. (2017). *Simple and scalable predictive uncertainty estimation
using deep ensembles*. Advances in neural information
processing systems, 30.
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
"""
def __init__(
self,
problem,
models,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=True,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsembleSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
"""
super().__init__(
problem, models, optimizers, schedulers, weighting, use_lt
)
# check consistency
check_consistency(ensemble_dim, int)
self._ensemble_dim = ensemble_dim
def forward(self, x, ensemble_idx=None):
"""
Forward pass through the ensemble models. If an ``ensemble_idx`` is
provided, it returns the output of the specific model
corresponding to that index. If no index is given, it stacks the outputs
of all models along the ensemble dimension.
:param LabelTensor x: The input tensor to the models.
:param int ensemble_idx: Optional index to select a specific
model from the ensemble. If ``None``, the outputs of all models are
stacked along the ``ensemble_dim`` dimension. Default is ``None``.
:return: The output of the selected model or the stacked
outputs from all models.
:rtype: LabelTensor
"""
# if an index is passed, return the specific model output for that index
if ensemble_idx is not None:
return self.models[ensemble_idx].forward(x)
# otherwise return the stacked output
return torch.stack(
[self.forward(x, idx) for idx in range(self.num_ensemble)],
dim=self.ensemble_dim,
)
def training_step(self, batch):
"""
Training step for the solver, overridden for manual optimization.
This method performs a forward pass, calculates the loss, and applies
manual backward propagation and optimization steps for each model in
the ensemble.
:param list[tuple[str, dict]] batch: A batch of training data.
Each element is a tuple containing a condition name and a
dictionary of points.
:return: The aggregated loss after the training step.
:rtype: torch.Tensor
"""
# zero grad for optimizer
for opt in self.optimizers:
opt.instance.zero_grad()
# perform forward passes and aggregate losses
loss = super().training_step(batch)
# perform backpropagation
self.manual_backward(loss)
# optimize
for opt, sched in zip(self.optimizers, self.schedulers):
opt.instance.step()
sched.instance.step()
return loss
@property
def ensemble_dim(self):
"""
The dimension along which the ensemble outputs are stacked.
:return: The ensemble dimension.
:rtype: int
"""
return self._ensemble_dim
@property
def num_ensemble(self):
"""
The number of models in the ensemble.
:return: The number of models in the ensemble.
:rtype: int
"""
return len(self.models)
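A minimal sketch of the stacking behaviour of ``forward`` and of the ensemble statistics described in the docstring, in plain PyTorch; the linear members below are illustrative stand-ins, not part of the PINA API.

import torch

ensemble_dim = 0
members = [torch.nn.Linear(2, 1) for _ in range(5)]
x = torch.rand(8, 2)

# stack every member's output along the ensemble dimension,
# mirroring what `forward` does when `ensemble_idx` is None
y = torch.stack([m(x) for m in members], dim=ensemble_dim)  # shape (5, 8, 1)

# ensemble mean and (biased) variance over the members
mu = y.mean(dim=ensemble_dim)                     # shape (8, 1)
sigma2 = y.var(dim=ensemble_dim, unbiased=False)  # shape (8, 1)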


@@ -0,0 +1,122 @@
"""Module for the DeepEnsemble supervised solver."""
from .ensemble_solver_interface import DeepEnsembleSolverInterface
from ..supervised_solver import SupervisedSolverInterface
class DeepEnsembleSupervisedSolver(
SupervisedSolverInterface, DeepEnsembleSolverInterface
):
r"""
Deep Ensemble Supervised Solver class. This class implements a
Deep Ensemble Supervised Solver, using user-specified ``models`` to solve
a specific ``problem``.
An ensemble model is constructed by combining multiple models that solve
the same type of problem. Mathematically, this creates an implicit
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
The models :math:`\mathcal{M}_{i}`, :math:`i \in (1,\dots,r)`, in
the ensemble work collaboratively to capture different
aspects of the data or task, with each model contributing a distinct
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{s})`.
By aggregating these predictions, the ensemble
model can achieve greater robustness and accuracy compared to individual
models, leveraging the diversity of the models to reduce overfitting and
improve generalization. Furthermore, statistical metrics can
be computed, e.g. the ensemble mean and variance:
.. math::
\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^{r} \mathbf{y}_{i}
.. math::
\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^{r}
(\mathbf{y}_{i} - \mathbf{\mu})^2
During training, the supervised loss is minimized by each ensemble model:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}_{j}(\mathbf{s}_i)),
\quad j \in (1,\dots,r)
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{s}_i` denote
the multiple (discretised) target functions to approximate and the
corresponding (discretised) input functions.
.. seealso::
**Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell,
C. (2017). *Simple and scalable predictive uncertainty estimation
using deep ensembles*. Advances in neural information
processing systems, 30.
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
"""
def __init__(
self,
problem,
models,
loss=None,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=False,
ensemble_dim=0,
):
"""
Initialization of the :class:`DeepEnsembleSupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler schedulers: The learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``False``.
:param int ensemble_dim: The dimension along which the ensemble
outputs are stacked. Default is 0.
"""
super().__init__(
problem=problem,
models=models,
loss=loss,
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
use_lt=use_lt,
ensemble_dim=ensemble_dim,
)
def loss_data(self, input, target):
"""
Compute the data loss for the :class:`DeepEnsembleSupervisedSolver` by
evaluating the loss between the network's output and the true solution
for each model. This method should not be overridden unless done
intentionally.
:param input: The input to the neural network.
:type input: LabelTensor | torch.Tensor | Graph | Data
:param target: The target to compare with the network's output.
:type target: LabelTensor | torch.Tensor | Graph | Data
:return: The supervised loss, averaged over the ensemble models.
:rtype: torch.Tensor
"""
predictions = self.forward(input)
loss = sum(
self._loss_fn(predictions[idx], target)
for idx in range(self.num_ensemble)
)
return loss / self.num_ensemble
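To make the uncertainty-estimation use case from the Lakshminarayanan et al. reference concrete, here is a hedged end-to-end sketch in plain PyTorch (the data, member sizes, and training loop are assumptions, not the PINA ``Trainer`` workflow): each member is fitted independently on the same data, and the spread of the stacked predictions serves as a simple predictive-uncertainty estimate.

import torch

torch.manual_seed(0)
# toy supervised data: noisy samples of sin(3x)
x = torch.linspace(-1, 1, 64).unsqueeze(-1)
target = torch.sin(3 * x) + 0.05 * torch.randn_like(x)

# four independently initialized members (illustrative sizes)
members = [
    torch.nn.Sequential(
        torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
    )
    for _ in range(4)
]
for model in members:  # each member minimizes its own supervised MSE
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(200):
        optimizer.zero_grad()
        torch.nn.functional.mse_loss(model(x), target).backward()
        optimizer.step()

# stacked predictions -> ensemble mean and spread (uncertainty proxy)
predictions = torch.stack([m(x) for m in members], dim=0)  # shape (4, 64, 1)
mean, std = predictions.mean(dim=0), predictions.std(dim=0)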