Update MultiSolverInterface (#520)
This commit is contained in:
@@ -103,9 +103,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# Set automatic optimization to False
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
@@ -158,9 +158,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# Set automatic optimization to False
|
||||
self.automatic_optimization = False
|
||||
|
||||
self._vectorial_loss = deepcopy(self.loss)
|
||||
self._vectorial_loss.reduction = "none"
|
||||
|
||||
|
||||
@@ -14,9 +14,13 @@ from ..utils import check_consistency, labelize_forward
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract base class for PINA solvers. All specific solvers should inherit
|
||||
from this interface. This class is a wrapper of
|
||||
:class:`~lightning.pytorch.LightningModule`.
|
||||
Abstract base class for PINA solvers. All specific solvers must inherit
|
||||
from this interface. This class extends
|
||||
:class:`~lightning.pytorch.core.LightningModule`, providing additional
|
||||
functionalities for defining and optimizing Deep Learning models.
|
||||
|
||||
By inheriting from this base class, solvers gain access to built-in training
|
||||
loops, logging utilities, and optimization techniques.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, weighting, use_lt):
|
||||
@@ -442,6 +446,14 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
:raises ValueError: If the models are not a list or tuple with length
|
||||
greater than one.
|
||||
|
||||
.. warning::
|
||||
:class:`MultiSolverInterface` uses manual optimization by setting
|
||||
``automatic_optimization=False`` in
|
||||
:class:`~lightning.pytorch.core.LightningModule`. For more
|
||||
information on manual optimization please
|
||||
see `here <https://lightning.ai/docs/pytorch/stable/\
|
||||
model/manual_optimization.html>`_.
|
||||
"""
|
||||
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
||||
raise ValueError(
|
||||
@@ -450,6 +462,16 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"one."
|
||||
)
|
||||
|
||||
if optimizers is None:
|
||||
optimizers = [
|
||||
self.default_torch_optimizer() for _ in range(len(models))
|
||||
]
|
||||
|
||||
if schedulers is None:
|
||||
schedulers = [
|
||||
self.default_torch_scheduler() for _ in range(len(models))
|
||||
]
|
||||
|
||||
if any(opt is None for opt in optimizers):
|
||||
optimizers = [
|
||||
self.default_torch_optimizer() if opt is None else opt
|
||||
@@ -480,12 +502,23 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
f"Got {len(models)} models, and {len(optimizers)}"
|
||||
" optimizers."
|
||||
)
|
||||
if len(schedulers) != len(optimizers):
|
||||
raise ValueError(
|
||||
"You must define one scheduler for each optimizer."
|
||||
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
|
||||
" optimizers."
|
||||
)
|
||||
|
||||
# initialize the model
|
||||
self._pina_models = torch.nn.ModuleList(models)
|
||||
self._pina_optimizers = optimizers
|
||||
self._pina_schedulers = schedulers
|
||||
|
||||
# Set automatic optimization to False.
|
||||
# For more information on manual optimization see:
|
||||
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
|
||||
self.automatic_optimization = False
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the solver.
|
||||
|
||||
Reference in New Issue
Block a user