Refactoring solvers (#541)

* Refactoring solvers

* Simplify compile logic
* Improve and update documentation
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests for ensemble solvers (new class hierarchy sketched below)

* formatter

* codacy

* fix issues + speed up tests
Author: Dario Coscia
Date: 2025-04-09 14:51:42 +02:00
Committed by: FilippoOlivo
Parent: fa6fda0bd5
Commit: 1bb3c125ac

37 changed files with 1514 additions and 510 deletions
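
The commit message names five new classes. Below is a rough sketch of how they plausibly fit together; only the class names come from the message, while the base classes, constructor, and method bodies are illustrative assumptions, not the library's actual implementation:

    from abc import ABC, abstractmethod

    import torch


    class SupervisedSolverInterface(ABC):
        """Shared logic for solvers trained on input/target pairs (role assumed)."""

        def __init__(self, model, loss=None):
            self.model = model
            # the docstrings below document MSELoss as the default loss
            self._loss_fn = loss if loss is not None else torch.nn.MSELoss()

        def forward(self, x):
            return self.model(x)

        @abstractmethod
        def loss_data(self, input, target):
            """Compare the network's output with the supervised target."""


    class SupervisedSolver(SupervisedSolverInterface):
        """Plain supervised specialization."""

        def loss_data(self, input, target):
            return self._loss_fn(self.forward(input), target)


    class ReducedOrderModelSolver(SupervisedSolverInterface):
        """Supervised solver specialized for reduced-order modelling."""

        def loss_data(self, input, target):
            # a real implementation would insert encode/decode steps here
            return self._loss_fn(self.forward(input), target)


    class EnsembleSolverInterface(ABC):
        """Shared logic for solvers wrapping several models (role assumed)."""


    class EnsembleSupervisedSolver(EnsembleSolverInterface, SupervisedSolver):
        """Supervised training applied across an ensemble (layout assumed)."""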


@@ -120,24 +120,24 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         :param torch.nn.Module weight_function: The Self-Adaptive mask model.
             Default is ``torch.nn.Sigmoid()``.
         :param Optimizer optimizer_model: The optimizer of the ``model``.
-            If `None`, the :class:`torch.optim.Adam` optimizer is used.
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
             Default is ``None``.
         :param Optimizer optimizer_weights: The optimizer of the
             ``weight_function``.
-            If `None`, the :class:`torch.optim.Adam` optimizer is used.
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
             Default is ``None``.
         :param Scheduler scheduler_model: Learning rate scheduler for the
             ``model``.
-            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
             scheduler is used. Default is ``None``.
         :param Scheduler scheduler_weights: Learning rate scheduler for the
             ``weight_function``.
-            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
             scheduler is used. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
-            If `None`, no weighting schema is used. Default is ``None``.
+            If ``None``, no weighting schema is used. Default is ``None``.
         :param torch.nn.Module loss: The loss function to be minimized.
-            If `None`, the :class:`torch.nn.MSELoss` loss is used.
+            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         """
         # check consistency of weight_function
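
The hunk above only tightens the docstring markup (single backticks to RST double backticks for literals). For orientation, here is a minimal construction sketch using every default the docstring describes; the import path and the `problem`/`model` arguments are assumptions, not taken from this diff:

    import torch
    from pina.solver import SelfAdaptivePINN  # import path assumed

    model = torch.nn.Sequential(              # stand-in approximating network
        torch.nn.Linear(2, 20), torch.nn.Tanh(), torch.nn.Linear(20, 1)
    )

    solver = SelfAdaptivePINN(
        problem=problem,                      # assumed: a PINA problem defined elsewhere
        model=model,
        weight_function=torch.nn.Sigmoid(),   # default Self-Adaptive mask
        optimizer_model=None,                 # -> torch.optim.Adam
        optimizer_weights=None,               # -> torch.optim.Adam
        scheduler_model=None,                 # -> torch.optim.lr_scheduler.ConstantLR
        scheduler_weights=None,               # -> torch.optim.lr_scheduler.ConstantLR
        weighting=None,                       # no weighting schema
        loss=None,                            # -> torch.nn.MSELoss
    )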
@@ -223,24 +223,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             [self.scheduler_model.instance, self.scheduler_weights.instance],
         )
 
-    def on_train_batch_end(self, outputs, batch, batch_idx):
-        """
-        This method is called at the end of each training batch and overrides
-        the PyTorch Lightning implementation to log checkpoints.
-
-        :param torch.Tensor outputs: The ``model``'s output for the current
-            batch.
-        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
-            tuple containing a condition name and a dictionary of points.
-        :param int batch_idx: The index of the current batch.
-        """
-        # increase the optimization step counter by one so loggers are saved
-        (
-            self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
-        ) += 1
-        return super().on_train_batch_end(outputs, batch, batch_idx)
-
     def on_train_start(self):
         """
         This method is called at the start of the training process to set the
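
For context on the deletion above: under Lightning's manual optimization (which SelfAdaptivePINN presumably uses, since it steps two optimizers), step-based loggers and checkpoints only advance if the optimizer-step counter is bumped by hand. A sketch of that pattern in isolation, assuming PyTorch Lightning 2.x loop internals as quoted in the removed code:

    import lightning.pytorch as pl


    class ManualOptimizationExample(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # manual optimization, as implied above

        def on_train_batch_end(self, outputs, batch, batch_idx):
            # Bump the manual-optimization step counter so step-based
            # logging/checkpointing keeps advancing -- the same bookkeeping
            # the removed override performed.
            progress = self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress
            progress.total.completed += 1
            return super().on_train_batch_end(outputs, batch, batch_idx)

Its removal here presumably means the refactored base solvers now handle this bookkeeping themselves, making the subclass override redundant.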
@@ -304,6 +286,21 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         )
         return self._vect_to_scalar(weights * loss_value)
 
+    def loss_data(self, input, target):
+        """
+        Compute the data loss for the PINN solver by evaluating the loss
+        between the network's output and the true solution. This method
+        should not be overridden unless done intentionally.
+
+        :param input: The input to the neural network.
+        :type input: LabelTensor
+        :param target: The target to compare with the network's output.
+        :type target: LabelTensor
+        :return: The supervised loss, averaged over the number of observations.
+        :rtype: LabelTensor
+        """
+        return self._loss_fn(self.forward(input), target)
+
     def _vect_to_scalar(self, loss_value):
         """
         Computation of the scalar loss.
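
The newly added `loss_data` reduces to a single forward-and-compare step. A self-contained sketch of the equivalent computation with the documented default `torch.nn.MSELoss`, using plain tensors in place of `LabelTensor`:

    import torch

    loss_fn = torch.nn.MSELoss()      # documented default loss
    network = torch.nn.Linear(2, 1)   # stand-in for the solver's model

    inp = torch.rand(16, 2)           # batch of 16 observations
    target = torch.rand(16, 1)

    # equivalent of `self._loss_fn(self.forward(input), target)`
    data_loss = loss_fn(network(inp), target)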