Update MultiSolverInterface (#520)

This commit is contained in:
Dario Coscia
2025-04-07 14:13:26 +02:00
committed by GitHub
parent 4357f8681f
commit 0a60ed4c9a
3 changed files with 36 additions and 9 deletions

View File

@@ -103,9 +103,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
loss=loss, loss=loss,
) )
# Set automatic optimization to False
self.automatic_optimization = False
def forward(self, x): def forward(self, x):
""" """
Forward pass. Forward pass.

View File

@@ -158,9 +158,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
loss=loss, loss=loss,
) )
# Set automatic optimization to False
self.automatic_optimization = False
self._vectorial_loss = deepcopy(self.loss) self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none" self._vectorial_loss.reduction = "none"

View File

@@ -14,9 +14,13 @@ from ..utils import check_consistency, labelize_forward
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
""" """
Abstract base class for PINA solvers. All specific solvers should inherit Abstract base class for PINA solvers. All specific solvers must inherit
from this interface. This class is a wrapper of from this interface. This class extends
:class:`~lightning.pytorch.LightningModule`. :class:`~lightning.pytorch.core.LightningModule`, providing additional
functionalities for defining and optimizing Deep Learning models.
By inheriting from this base class, solvers gain access to built-in training
loops, logging utilities, and optimization techniques.
""" """
def __init__(self, problem, weighting, use_lt): def __init__(self, problem, weighting, use_lt):
@@ -442,6 +446,14 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
:param bool use_lt: If ``True``, the solver uses LabelTensors as input. :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length :raises ValueError: If the models are not a list or tuple with length
greater than one. greater than one.
.. warning::
:class:`MultiSolverInterface` uses manual optimization by setting
``automatic_optimization=False`` in
:class:`~lightning.pytorch.core.LightningModule`. For more
information on manual optimization please
see `here <https://lightning.ai/docs/pytorch/stable/\
model/manual_optimization.html>`_.
""" """
if not isinstance(models, (list, tuple)) or len(models) < 2: if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError( raise ValueError(
@@ -450,6 +462,16 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
"one." "one."
) )
if optimizers is None:
optimizers = [
self.default_torch_optimizer() for _ in range(len(models))
]
if schedulers is None:
schedulers = [
self.default_torch_scheduler() for _ in range(len(models))
]
if any(opt is None for opt in optimizers): if any(opt is None for opt in optimizers):
optimizers = [ optimizers = [
self.default_torch_optimizer() if opt is None else opt self.default_torch_optimizer() if opt is None else opt
@@ -480,12 +502,23 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
f"Got {len(models)} models, and {len(optimizers)}" f"Got {len(models)} models, and {len(optimizers)}"
" optimizers." " optimizers."
) )
if len(schedulers) != len(optimizers):
raise ValueError(
"You must define one scheduler for each optimizer."
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
" optimizers."
)
# initialize the model # initialize the model
self._pina_models = torch.nn.ModuleList(models) self._pina_models = torch.nn.ModuleList(models)
self._pina_optimizers = optimizers self._pina_optimizers = optimizers
self._pina_schedulers = schedulers self._pina_schedulers = schedulers
# Set automatic optimization to False.
# For more information on manual optimization see:
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
self.automatic_optimization = False
def configure_optimizers(self): def configure_optimizers(self):
""" """
Optimizer configuration for the solver. Optimizer configuration for the solver.