Update MultiSolverInterface (#520)

This commit is contained in:
Dario Coscia
2025-04-07 14:13:26 +02:00
committed by FilippoOlivo
parent 578c5bc2f4
commit e250e3f5f7
3 changed files with 36 additions and 9 deletions

View File

@@ -14,9 +14,13 @@ from ..utils import check_consistency, labelize_forward
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Abstract base class for PINA solvers. All specific solvers should inherit
from this interface. This class is a wrapper of
:class:`~lightning.pytorch.LightningModule`.
Abstract base class for PINA solvers. All specific solvers must inherit
from this interface. This class extends
:class:`~lightning.pytorch.core.LightningModule`, providing additional
functionalities for defining and optimizing Deep Learning models.
By inheriting from this base class, solvers gain access to built-in training
loops, logging utilities, and optimization techniques.
"""
def __init__(self, problem, weighting, use_lt):
@@ -442,6 +446,14 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length
greater than one.
.. warning::
:class:`MultiSolverInterface` uses manual optimization by setting
``automatic_optimization=False`` in
:class:`~lightning.pytorch.core.LightningModule`. For more
information on manual optimization please
see `here <https://lightning.ai/docs/pytorch/stable/\
model/manual_optimization.html>`_.
"""
if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError(
@@ -450,6 +462,16 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
"one."
)
if optimizers is None:
optimizers = [
self.default_torch_optimizer() for _ in range(len(models))
]
if schedulers is None:
schedulers = [
self.default_torch_scheduler() for _ in range(len(models))
]
if any(opt is None for opt in optimizers):
optimizers = [
self.default_torch_optimizer() if opt is None else opt
@@ -480,12 +502,23 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
f"Got {len(models)} models, and {len(optimizers)}"
" optimizers."
)
if len(schedulers) != len(optimizers):
raise ValueError(
"You must define one scheduler for each optimizer."
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
" optimizers."
)
# initialize the model
self._pina_models = torch.nn.ModuleList(models)
self._pina_optimizers = optimizers
self._pina_schedulers = schedulers
# Set automatic optimization to False.
# For more information on manual optimization see:
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
self.automatic_optimization = False
def configure_optimizers(self):
"""
Optimizer configuration for the solver.