Update MultiSolverInterface (#520)
This commit is contained in:
@@ -103,9 +103,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
loss=loss,
|
loss=loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set automatic optimization to False
|
|
||||||
self.automatic_optimization = False
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass.
|
Forward pass.
|
||||||
|
|||||||
@@ -158,9 +158,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
loss=loss,
|
loss=loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set automatic optimization to False
|
|
||||||
self.automatic_optimization = False
|
|
||||||
|
|
||||||
self._vectorial_loss = deepcopy(self.loss)
|
self._vectorial_loss = deepcopy(self.loss)
|
||||||
self._vectorial_loss.reduction = "none"
|
self._vectorial_loss.reduction = "none"
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,13 @@ from ..utils import check_consistency, labelize_forward
|
|||||||
|
|
||||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Abstract base class for PINA solvers. All specific solvers should inherit
|
Abstract base class for PINA solvers. All specific solvers must inherit
|
||||||
from this interface. This class is a wrapper of
|
from this interface. This class extends
|
||||||
:class:`~lightning.pytorch.LightningModule`.
|
:class:`~lightning.pytorch.core.LightningModule`, providing additional
|
||||||
|
functionalities for defining and optimizing Deep Learning models.
|
||||||
|
|
||||||
|
By inheriting from this base class, solvers gain access to built-in training
|
||||||
|
loops, logging utilities, and optimization techniques.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, problem, weighting, use_lt):
|
def __init__(self, problem, weighting, use_lt):
|
||||||
@@ -442,6 +446,14 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
:raises ValueError: If the models are not a list or tuple with length
|
:raises ValueError: If the models are not a list or tuple with length
|
||||||
greater than one.
|
greater than one.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:class:`MultiSolverInterface` uses manual optimization by setting
|
||||||
|
``automatic_optimization=False`` in
|
||||||
|
:class:`~lightning.pytorch.core.LightningModule`. For more
|
||||||
|
information on manual optimization please
|
||||||
|
see `here <https://lightning.ai/docs/pytorch/stable/\
|
||||||
|
model/manual_optimization.html>`_.
|
||||||
"""
|
"""
|
||||||
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -450,6 +462,16 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
"one."
|
"one."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if optimizers is None:
|
||||||
|
optimizers = [
|
||||||
|
self.default_torch_optimizer() for _ in range(len(models))
|
||||||
|
]
|
||||||
|
|
||||||
|
if schedulers is None:
|
||||||
|
schedulers = [
|
||||||
|
self.default_torch_scheduler() for _ in range(len(models))
|
||||||
|
]
|
||||||
|
|
||||||
if any(opt is None for opt in optimizers):
|
if any(opt is None for opt in optimizers):
|
||||||
optimizers = [
|
optimizers = [
|
||||||
self.default_torch_optimizer() if opt is None else opt
|
self.default_torch_optimizer() if opt is None else opt
|
||||||
@@ -480,12 +502,23 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
f"Got {len(models)} models, and {len(optimizers)}"
|
f"Got {len(models)} models, and {len(optimizers)}"
|
||||||
" optimizers."
|
" optimizers."
|
||||||
)
|
)
|
||||||
|
if len(schedulers) != len(optimizers):
|
||||||
|
raise ValueError(
|
||||||
|
"You must define one scheduler for each optimizer."
|
||||||
|
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
|
||||||
|
" optimizers."
|
||||||
|
)
|
||||||
|
|
||||||
# initialize the model
|
# initialize the model
|
||||||
self._pina_models = torch.nn.ModuleList(models)
|
self._pina_models = torch.nn.ModuleList(models)
|
||||||
self._pina_optimizers = optimizers
|
self._pina_optimizers = optimizers
|
||||||
self._pina_schedulers = schedulers
|
self._pina_schedulers = schedulers
|
||||||
|
|
||||||
|
# Set automatic optimization to False.
|
||||||
|
# For more information on manual optimization see:
|
||||||
|
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
|
||||||
|
self.automatic_optimization = False
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""
|
"""
|
||||||
Optimizer configuration for the solver.
|
Optimizer configuration for the solver.
|
||||||
|
|||||||
Reference in New Issue
Block a user