Refactoring solvers (#541)
* Refactoring solvers
* Simplify logic compile
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests ensemble solvers
* formatter
* codacy
* fix issues + speedup test
committed by FilippoOlivo
parent fa6fda0bd5
commit 1bb3c125ac
@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
 import lightning
 import torch

-from torch._dynamo.eval_frame import OptimizedModule
+from torch._dynamo import OptimizedModule
 from ..problem import AbstractProblem
 from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from ..loss import WeightingInterface
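The import move above matters for the compilation checks added later in this diff, which test models against ``OptimizedModule``. A standalone sketch, assuming only stock PyTorch, of what that class represents: ``torch.compile`` returns the original module wrapped in an ``OptimizedModule``.

    import torch
    from torch._dynamo import OptimizedModule

    # torch.compile wraps the module; the wrapper type is what the
    # solver's "is it already compiled?" checks look for.
    net = torch.nn.Linear(2, 2)
    compiled = torch.compile(net, backend="eager")
    print(isinstance(compiled, OptimizedModule))  # True
    print(isinstance(net, OptimizedModule))       # False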
@@ -29,7 +29,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):

         :param AbstractProblem problem: The problem to be solved.
         :param WeightingInterface weighting: The weighting schema to be used.
-            If `None`, no weighting schema is used. Default is ``None``.
+            If ``None``, no weighting schema is used. Default is ``None``.
         :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
         """
         super().__init__()
@@ -64,18 +64,20 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         self._pina_optimizers = None
         self._pina_schedulers = None

-    def _check_solver_consistency(self, problem):
+    @abstractmethod
+    def forward(self, *args, **kwargs):
         """
-        Check the consistency of the solver with the problem formulation.
+        Abstract method for the forward pass implementation.

-        :param AbstractProblem problem: The problem to be solved.
+        :param args: The input tensor.
+        :type args: torch.Tensor | LabelTensor | Data | Graph
+        :param dict kwargs: Additional keyword arguments.
         """
-        for condition in problem.conditions.values():
-            check_consistency(condition, self.accepted_conditions_types)

-    def _optimization_cycle(self, batch):
+    @abstractmethod
+    def optimization_cycle(self, batch):
         """
-        Aggregate the loss for each condition in the batch.
+        The optimization cycle for the solvers.

         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
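For orientation, a minimal sketch of what a concrete subclass is expected to provide for these two abstract hooks. ``ToySolver`` and the MSE loss are illustrative stand-ins, not code from this commit; ``batch`` follows the documented list-of-(name, points) layout.

    import torch

    class ToySolver:
        # Stand-in for a SolverInterface subclass; only the two hooks matter.
        def __init__(self, model):
            self.model = model

        def forward(self, x):
            return self.model(x)

        def optimization_cycle(self, batch):
            # Return one scalar loss per condition, keyed by condition name.
            return {
                name: torch.nn.functional.mse_loss(
                    self.model(points["input"]), points["target"]
                )
                for name, points in batch
            }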
@@ -84,46 +86,58 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
             containing the condition name and the associated scalar loss.
         :rtype: dict
         """
-        losses = self.optimization_cycle(batch)
-        for name, value in losses.items():
-            self.store_log(
-                f"{name}_loss", value.item(), self.get_batch_size(batch)
-            )
-        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
-        return loss

-    def training_step(self, batch):
+    def training_step(self, batch, **kwargs):
         """
-        Solver training step.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.

         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
         :return: The loss of the training step.
-        :rtype: LabelTensor
+        :rtype: torch.Tensor
         """
-        loss = self._optimization_cycle(batch=batch)
+        loss = self._optimization_cycle(batch=batch, **kwargs)
         self.store_log("train_loss", loss, self.get_batch_size(batch))
         return loss

-    def validation_step(self, batch):
+    def validation_step(self, batch, **kwargs):
         """
-        Solver validation step.
+        Solver validation step. It computes the optimization cycle and
+        averages the losses. No aggregation using the ``weighting`` attribute is
+        performed.

         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
         """
-        loss = self._optimization_cycle(batch=batch)
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss

-    def test_step(self, batch):
+    def test_step(self, batch, **kwargs):
         """
-        Solver test step.
+        Solver test step. It computes the optimization cycle and
+        averages the losses. No aggregation using the ``weighting`` attribute is
+        performed.

         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
         """
-        loss = self._optimization_cycle(batch=batch)
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss

     def store_log(self, name, value, batch_size):
         """
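The split above is the heart of this refactor: training aggregates per-condition losses through the ``weighting`` schema, while validation and test average them uniformly. A self-contained sketch of the uniform reduction (the condition names and loss values are made up):

    import torch

    losses = {"data": torch.tensor(0.30), "physics": torch.tensor(0.10)}
    loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
    print(loss)  # tensor(0.2000)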
@@ -141,58 +155,118 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
             **self.trainer.logging_kwargs,
         )

-    @abstractmethod
-    def forward(self, *args, **kwargs):
+    def setup(self, stage):
         """
-        Abstract method for the forward pass implementation.
-
-        :param args: The input tensor.
-        :type args: torch.Tensor | LabelTensor
-        :param dict kwargs: Additional keyword arguments.
+        This method is called at the start of the train and test process to
+        compile the model if the :class:`~pina.trainer.Trainer`
+        ``compile`` is ``True``.
         """
-
-    @abstractmethod
-    def optimization_cycle(self, batch):
-        """
-        The optimization cycle for the solvers.
-
-        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
-            tuple containing a condition name and a dictionary of points.
-        :return: The losses computed for all conditions in the batch, casted
-            to a subclass of :class:`torch.Tensor`. It should return a dict
-            containing the condition name and the associated scalar loss.
-        :rtype: dict
-        """
+        if stage == "fit" and self.trainer.compile:
+            self._setup_compile()
+        if stage == "test" and (
+            self.trainer.compile and not self._is_compiled()
+        ):
+            self._setup_compile()
+        return super().setup(stage)
+
+    def _is_compiled(self):
+        """
+        Check if the model is compiled.
+
+        :return: ``True`` if the model is compiled, ``False`` otherwise.
+        :rtype: bool
+        """
+        for model in self._pina_models:
+            if not isinstance(model, OptimizedModule):
+                return False
+        return True
+
+    def _setup_compile(self):
+        """
+        Compile all models in the solver using ``torch.compile``.
+
+        This method iterates through each model stored in the solver
+        list and attempts to compile them for optimized execution. It supports
+        models of type `torch.nn.Module` and `torch.nn.ModuleDict`. For models
+        stored in a `ModuleDict`, each submodule is compiled individually.
+        Models on Apple Silicon (MPS) use the 'eager' backend,
+        while others use 'inductor'.
+
+        :raises RuntimeError: If a model is neither `torch.nn.Module`
+            nor `torch.nn.ModuleDict`.
+        """
+        for i, model in enumerate(self._pina_models):
+            if isinstance(model, torch.nn.ModuleDict):
+                for name, module in model.items():
+                    self._pina_models[i][name] = self._compile_modules(module)
+            elif isinstance(model, torch.nn.Module):
+                self._pina_models[i] = self._compile_modules(model)
+            else:
+                raise RuntimeError(
+                    "Compilation available only for "
+                    "torch.nn.Module or torch.nn.ModuleDict."
+                )
+
+    def _check_solver_consistency(self, problem):
+        """
+        Check the consistency of the solver with the problem formulation.
+
+        :param AbstractProblem problem: The problem to be solved.
+        """
+        for condition in problem.conditions.values():
+            check_consistency(condition, self.accepted_conditions_types)
+
+    def _optimization_cycle(self, batch, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, cast
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        losses = self.optimization_cycle(batch, **kwargs)
+        for name, value in losses.items():
+            self.store_log(
+                f"{name}_loss", value.item(), self.get_batch_size(batch)
+            )
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+        return loss

-    @property
-    def problem(self):
-        """
-        The problem instance.
-
-        :return: The problem instance.
-        :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
-        """
-        return self._pina_problem
-
-    @property
-    def use_lt(self):
-        """
-        Using LabelTensors as input during training.
-
-        :return: The use_lt attribute.
-        :rtype: bool
-        """
-        return self._use_lt
-
-    @property
-    def weighting(self):
-        """
-        The weighting schema.
-
-        :return: The weighting schema.
-        :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
-        """
-        return self._pina_weighting
+    @staticmethod
+    def _compile_modules(model):
+        """
+        Perform the compilation of the model.
+
+        This method attempts to compile the given PyTorch model
+        using ``torch.compile`` to improve execution performance. The
+        backend is selected based on the device on which the model resides:
+        ``eager`` is used for MPS devices (Apple Silicon), and ``inductor``
+        is used for all others.
+
+        If compilation fails, the method prints the error and returns the
+        original, uncompiled model.
+
+        :param torch.nn.Module model: The model to compile.
+        :raises Exception: If the compilation fails.
+        :return: The compiled model.
+        :rtype: torch.nn.Module
+        """
+        model_device = next(model.parameters()).device
+        try:
+            if model_device == torch.device("mps:0"):
+                model = torch.compile(model, backend="eager")
+            else:
+                model = torch.compile(model, backend="inductor")
+        except Exception as e:
+            print("Compilation failed, running in normal mode:\n", e)
+        return model

     @staticmethod
     def get_batch_size(batch):
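A standalone sketch of the compilation policy implemented above, assuming only stock PyTorch: use the ``eager`` backend on Apple-Silicon MPS devices and ``inductor`` elsewhere, falling back to the uncompiled module on failure. Note this sketch checks ``device.type`` rather than the literal ``mps:0`` comparison in the diff.

    import torch

    def compile_with_fallback(model: torch.nn.Module) -> torch.nn.Module:
        device = next(model.parameters()).device
        backend = "eager" if device.type == "mps" else "inductor"
        try:
            return torch.compile(model, backend=backend)
        except Exception as err:
            # Compilation is best-effort: keep the uncompiled model.
            print("Compilation failed, running in normal mode:\n", err)
            return model

    net = compile_with_fallback(torch.nn.Linear(4, 2))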
@@ -232,62 +306,35 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):

         return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)

-    def on_train_start(self):
-        """
-        This method is called at the start of the training process to compile
-        the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
-        """
-        super().on_train_start()
-        if self.trainer.compile:
-            self._compile_model()
-
-    def on_test_start(self):
-        """
-        This method is called at the start of the test process to compile
-        the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
-        """
-        super().on_train_start()
-        if self.trainer.compile and not self._check_already_compiled():
-            self._compile_model()
-
-    def _check_already_compiled(self):
-        """
-        Check if the model is already compiled.
-
-        :return: ``True`` if the model is already compiled, ``False`` otherwise.
-        :rtype: bool
-        """
-        models = self._pina_models
-        if len(models) == 1 and isinstance(
-            self._pina_models[0], torch.nn.ModuleDict
-        ):
-            models = list(self._pina_models.values())
-        for model in models:
-            if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
-                return False
-        return True
-
-    @staticmethod
-    def _perform_compilation(model):
-        """
-        Perform the compilation of the model.
-
-        :param torch.nn.Module model: The model to compile.
-        :raises Exception: If the compilation fails.
-        :return: The compiled model.
-        :rtype: torch.nn.Module
-        """
-
-        model_device = next(model.parameters()).device
-        try:
-            if model_device == torch.device("mps:0"):
-                model = torch.compile(model, backend="eager")
-            else:
-                model = torch.compile(model, backend="inductor")
-        except Exception as e:
-            print("Compilation failed, running in normal mode.:\n", e)
-        return model
+    @property
+    def problem(self):
+        """
+        The problem instance.
+
+        :return: The problem instance.
+        :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
+        """
+        return self._pina_problem
+
+    @property
+    def use_lt(self):
+        """
+        Using LabelTensors as input during training.
+
+        :return: The use_lt attribute.
+        :rtype: bool
+        """
+        return self._use_lt
+
+    @property
+    def weighting(self):
+        """
+        The weighting schema.
+
+        :return: The weighting schema.
+        :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
+        """
+        return self._pina_weighting


 class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
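The ``weighting`` property returns the object whose ``aggregate`` method ``_optimization_cycle`` calls during training. A hypothetical minimal schema satisfying that contract (``UniformWeighting`` is illustrative only; PINA's real implementations sit behind ``WeightingInterface``):

    import torch

    class UniformWeighting:
        # Illustrative: reduce a dict of per-condition losses to one scalar.
        def aggregate(self, losses):
            return sum(losses.values()) / len(losses)

    w = UniformWeighting()
    print(w.aggregate({"a": torch.tensor(1.0), "b": torch.tensor(3.0)}))  # tensor(2.)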
@@ -310,13 +357,13 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         :param AbstractProblem problem: The problem to be solved.
         :param torch.nn.Module model: The neural network model to be used.
         :param Optimizer optimizer: The optimizer to be used.
-            If `None`, the :class:`torch.optim.Adam` optimizer is
+            If ``None``, the :class:`torch.optim.Adam` optimizer is
             used. Default is ``None``.
         :param Scheduler scheduler: The scheduler to be used.
-            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
             scheduler is used. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
-            If `None`, no weighting schema is used. Default is ``None``.
+            If ``None``, no weighting schema is used. Default is ``None``.
         :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
         """
         if optimizer is None:
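A sketch of the documented defaults, using the wrapper classes imported at the top of this file. The ``TorchScheduler(torch.optim.lr_scheduler.ConstantLR)`` call appears verbatim earlier in this diff; the ``TorchOptimizer`` keyword arguments here are assumptions, not taken from this commit.

    import torch
    from pina.optim import TorchOptimizer, TorchScheduler  # assumed import path

    optimizer = TorchOptimizer(torch.optim.Adam, lr=1e-3)  # lr is illustrative
    scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)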
@@ -344,12 +391,11 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         Forward pass implementation.

         :param x: Input tensor.
-        :type x: torch.Tensor | LabelTensor
+        :type x: torch.Tensor | LabelTensor | Graph | Data
         :return: Solver solution.
-        :rtype: torch.Tensor | LabelTensor
+        :rtype: torch.Tensor | LabelTensor | Graph | Data
         """
-        x = self.model(x)
-        return x
+        return self.model(x)

     def configure_optimizers(self):
         """
@@ -362,28 +408,6 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         self.scheduler.hook(self.optimizer)
         return ([self.optimizer.instance], [self.scheduler.instance])

-    def _compile_model(self):
-        """
-        Compile the model.
-        """
-        if isinstance(self._pina_models[0], torch.nn.ModuleDict):
-            self._compile_module_dict()
-        else:
-            self._compile_single_model()
-
-    def _compile_module_dict(self):
-        """
-        Compile the model if it is a :class:`torch.nn.ModuleDict`.
-        """
-        for name, model in self._pina_models[0].items():
-            self._pina_models[0][name] = self._perform_compilation(model)
-
-    def _compile_single_model(self):
-        """
-        Compile the model if it is a single :class:`torch.nn.Module`.
-        """
-        self._pina_models[0] = self._perform_compilation(self._pina_models[0])
-
     @property
     def model(self):
         """
@@ -436,13 +460,13 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         :param models: The neural network models to be used.
         :type model: list[torch.nn.Module] | tuple[torch.nn.Module]
         :param list[Optimizer] optimizers: The optimizers to be used.
-            If `None`, the :class:`torch.optim.Adam` optimizer is used for all
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used for all
             models. Default is ``None``.
         :param list[Scheduler] schedulers: The schedulers to be used.
-            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
             scheduler is used for all the models. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
-            If `None`, no weighting schema is used. Default is ``None``.
+            If ``None``, no weighting schema is used. Default is ``None``.
         :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
         :raises ValueError: If the models are not a list or tuple with length
             greater than one.
@@ -519,6 +543,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         # http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
         self.automatic_optimization = False

+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        """
+        This method is called at the end of each training batch and overrides
+        the PyTorch Lightning implementation to log checkpoints.
+
+        :param torch.Tensor outputs: The ``model``'s output for the current
+            batch.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        """
+        # increase by one the counter of optimization to save loggers
+        epoch_loop = self.trainer.fit_loop.epoch_loop
+        epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+        return super().on_train_batch_end(outputs, batch, batch_idx)
+
     def configure_optimizers(self):
         """
         Optimizer configuration for the solver.
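Background for the counter bump above: ``MultiSolverInterface`` sets ``automatic_optimization = False``, so the solver steps its optimizers itself and advances Lightning's step-progress counter by hand to keep logging and checkpoint cadence in sync. A generic sketch of Lightning's manual-optimization pattern (standard API, not code from this commit):

    import lightning
    import torch

    class ManualModule(lightning.pytorch.LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # we step the optimizer ourselves
            self.net = torch.nn.Linear(2, 1)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            loss = self.net(batch).pow(2).mean()
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.net.parameters(), lr=1e-2)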
@@ -537,14 +577,6 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
             [scheduler.instance for scheduler in self.schedulers],
         )

-    def _compile_model(self):
-        """
-        Compile the model.
-        """
-        for i, model in enumerate(self._pina_models):
-            if not isinstance(model, torch.nn.ModuleDict):
-                self._pina_models[i] = self._perform_compilation(model)
-
     @property
     def models(self):
         """