Refactoring solvers (#541)

* Refactoring solvers

* Simplify logic compile
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests ensemble solvers

* formatter

* codacy

* fix issues + speedup test
This commit is contained in:
Dario Coscia
2025-04-09 14:51:42 +02:00
committed by FilippoOlivo
parent fa6fda0bd5
commit 1bb3c125ac
37 changed files with 1514 additions and 510 deletions

View File

@@ -62,15 +62,15 @@ class PINN(PINNInterface, SingleSolverInterface):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
super().__init__(
@@ -82,6 +82,21 @@ class PINN(PINNInterface, SingleSolverInterface):
loss=loss,
)
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden, if not intentionally.
:param input: The input to the neural network.
:type input: LabelTensor
:param target: The target to compare with the network's output.
:type target: LabelTensor
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
"""
return self._loss_fn(self.forward(input), target)
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
@@ -92,11 +107,8 @@ class PINN(PINNInterface, SingleSolverInterface):
:return: The computed physics loss.
:rtype: LabelTensor
"""
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
torch.zeros_like(residual, requires_grad=True), residual
)
return loss_value
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))
def configure_optimizers(self):
"""