fix doc solver

This commit is contained in:
giovanni
2025-03-13 19:05:15 +01:00
committed by Nicola Demo
parent 3f8665b5d8
commit 5d908a291d
6 changed files with 374 additions and 251 deletions


@@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Abstract base class for PINA solvers. All specific solvers should inherit
from this interface. This class is a wrapper of
:class:`~lightning.pytorch.LightningModule`.
"""
def __init__(self, problem, weighting, use_lt):
"""
Initialization of the :class:`SolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
super().__init__()
@@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
self._pina_schedulers = None
def _check_solver_consistency(self, problem):
"""
Check the consistency of the solver with the problem formulation.
:param AbstractProblem problem: The problem to be solved.
"""
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)
def _optimization_cycle(self, batch):
"""
Aggregate the loss for each condition in the batch. The per-condition
losses are computed first and then aggregated using the specific
weighting schema.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The computed loss for all conditions in the batch, cast to a
subclass of :class:`torch.Tensor`. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
losses = self.optimization_cycle(batch)
for name, value in losses.items():
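
In sketch form, the aggregation described above could look as follows. The ``aggregate`` method on the weighting schema and the plain-sum fallback are assumptions for illustration, not necessarily PINA's exact implementation:

    # Hedged sketch of the private optimization cycle described above.
    def _optimization_cycle(self, batch):
        # One scalar loss per condition, e.g. {"data": tensor(0.12), ...}.
        losses = self.optimization_cycle(batch)
        for name, value in losses.items():
            # Log each per-condition loss before aggregation.
            self.store_log(f"{name}_loss", value, self.get_batch_size(batch))
        if self.weighting is None:
            # Unweighted fallback: plain sum of the condition losses.
            return sum(losses.values())
        # Combine the losses according to the chosen weighting schema
        # (assumed interface: WeightingInterface.aggregate(dict) -> scalar).
        return self.weighting.aggregate(losses)
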
@@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver training step.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The loss of the training step.
:rtype: LabelTensor
"""
loss = self._optimization_cycle(batch=batch)
@@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver validation step.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
"""
loss = self._optimization_cycle(batch=batch)
self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver test step.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
"""
loss = self._optimization_cycle(batch=batch)
self.store_log("test_loss", loss, self.get_batch_size(batch))
def store_log(self, name, value, batch_size):
"""
Store the log of the solver.
:param str name: The name of the log.
:param torch.Tensor value: The value of the log.
:param int batch_size: The size of the batch.
"""
self.log(
@@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@abstractmethod
def forward(self, *args, **kwargs):
"""
Abstract method for the forward pass implementation.
"""
@abstractmethod
def optimization_cycle(self, batch):
"""
The optimization cycle for the solvers. It computes the loss for each
condition in the given batch.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The computed loss for all conditions in the batch, cast to a
subclass of :class:`torch.Tensor`. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
@property
def problem(self):
"""
The problem instance.
:return: The problem instance.
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
"""
return self._pina_problem
@property
def use_lt(self):
"""
Whether the solver uses LabelTensors as input during training.
:return: The use_lt attribute.
:rtype: bool
"""
return self._use_lt
@property
def weighting(self):
"""
The weighting schema.
:return: The weighting schema.
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
"""
return self._pina_weighting
@staticmethod
def get_batch_size(batch):
"""
Get the batch size.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The size of the batch.
:rtype: int
"""
batch_size = 0
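
A plausible completion of the counting loop, derived from the documented batch layout (each condition contributes a dictionary of tensors whose first dimension is the number of points); treat it as a sketch rather than the exact source:

    batch_size = 0
    for _condition_name, points in batch:
        # Count the points contributed by this condition.
        first_tensor = next(iter(points.values()))
        batch_size += len(first_tensor)
    return batch_size
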
@@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def default_torch_optimizer():
"""
TODO
"""
Set the default optimizer to :class:`torch.optim.Adam`.
:return: The default optimizer.
:rtype: Optimizer
"""
return TorchOptimizer(torch.optim.Adam, lr=0.001)
@staticmethod
def default_torch_scheduler():
"""
Return the default scheduler,
:class:`torch.optim.lr_scheduler.ConstantLR`.
:return: The default scheduler.
:rtype: Scheduler
"""
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
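
Usage of the two factories, for reference; the ``TorchOptimizer`` and ``TorchScheduler`` wrapper calls mirror the ones shown above:

    # Defaults used when no optimizer/scheduler is passed to a solver.
    optimizer = SolverInterface.default_torch_optimizer()
    # Equivalent explicit form: TorchOptimizer(torch.optim.Adam, lr=0.001)
    scheduler = SolverInterface.default_torch_scheduler()
    # Equivalent explicit form:
    # TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
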
def on_train_start(self):
"""
Hook called at the start of the training process. It compiles the model
if ``compile=True`` is set in the :class:`~pina.trainer.Trainer`.
"""
super().on_train_start()
if self.trainer.compile:
@@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
def on_test_start(self):
"""
Hook called at the start of the test process. It compiles the model
if ``compile=True`` is set in the :class:`~pina.trainer.Trainer`.
"""
super().on_test_start()
if self.trainer.compile and not self._check_already_compiled():
@@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
def _check_already_compiled(self):
"""
Check if the model is already compiled.
:return: ``True`` if the model is already compiled, ``False`` otherwise.
:rtype: bool
"""
models = self._pina_models
@@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def _perform_compilation(model):
"""
Perform the compilation of the model.
:param torch.nn.Module model: The model to compile.
:raises Exception: If the compilation fails.
:return: The compiled model.
:rtype: torch.nn.Module
"""
model_device = next(model.parameters()).device
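
A hedged sketch of the remainder of ``_perform_compilation``, consistent with the docstring (compile, keep the original device, re-raise on failure); the exact error handling in the source may differ:

    import torch

    def _perform_compilation(model):
        model_device = next(model.parameters()).device
        try:
            # torch.compile returns an optimized wrapper around the module.
            model = torch.compile(model)
        except Exception as exc:
            raise Exception("Model compilation failed") from exc
        return model.to(model_device)
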
@@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
"""TODO"""
"""
Base class for PINA solvers using a single :class:`torch.nn.Module`.
"""
def __init__(
self,
problem,
@@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
use_lt=True,
):
"""
Initialization of the :class:`SingleSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If ``None``, the Adam optimizer is used. Default is ``None``.
:param Scheduler scheduler: The scheduler to be used.
If ``None``, the constant learning rate scheduler is used.
Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
if optimizer is None:
optimizer = self.default_torch_optimizer()
@@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
def forward(self, x):
"""
Forward pass implementation.
:param x: Input tensor.
:type x: torch.Tensor | LabelTensor
:return: Solver solution.
:rtype: torch.Tensor | LabelTensor
"""
x = self.model(x)
return x
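
An example call, assuming a concrete solver instance ``solver`` whose model takes two input features; the labels are illustrative:

    import torch
    from pina import LabelTensor

    x = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
    # Calling the solver invokes forward() through LightningModule.__call__.
    u = solver(x)
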
@@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
"""
Optimizer configuration for the solver.
:return: The optimizer and the scheduler.
:rtype: tuple(list, list)
"""
self.optimizer.hook(self.model.parameters())
@@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
return ([self.optimizer.instance], [self.scheduler.instance])
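
Lightning consumes the returned pair of lists directly; in sketch form (the ``hook``/``instance`` attributes follow the calls shown above, where ``hook`` binds the wrapped optimizer to the model parameters):

    optimizers, schedulers = solver.configure_optimizers()
    torch_optimizer = optimizers[0]  # e.g. a torch.optim.Adam instance
    torch_scheduler = schedulers[0]  # e.g. a ConstantLR instance
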
def _compile_model(self):
"""
Compile the model.
"""
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
self._compile_module_dict()
else:
self._compile_single_model()
def _compile_module_dict(self):
"""
Compile the model if it is a :class:`torch.nn.ModuleDict`.
"""
for name, model in self._pina_models[0].items():
self._pina_models[0][name] = self._perform_compilation(model)
def _compile_single_model(self):
"""
Compile the model if it is a single :class:`torch.nn.Module`.
"""
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
@property
def model(self):
"""
The model used for training.
:return: The model used for training.
:rtype: torch.nn.Module
"""
return self._pina_models[0]
@property
def scheduler(self):
"""
The scheduler used for training.
:return: The scheduler used for training.
:rtype: Scheduler
"""
return self._pina_schedulers[0]
@property
def optimizer(self):
"""
The optimizer used for training.
:return: The optimizer used for training.
:rtype: Optimizer
"""
return self._pina_optimizers[0]
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
"""
Base class for PINA solvers using multiple :class:`torch.nn.Module`
instances.
"""
def __init__(
@@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
use_lt=True,
):
"""
Initialization of the :class:`MultiSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param models: The neural network models to be used.
:type models: list[torch.nn.Module] | tuple[torch.nn.Module]
:param list[Optimizer] optimizers: The optimizers to be used.
If ``None``, the Adam optimizer is used for all models.
Default is ``None``.
:param list[Scheduler] schedulers: The schedulers to be used.
If ``None``, the constant learning rate scheduler is used for all the
models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length
greater than one.
"""
if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError(
@@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
self._pina_schedulers = schedulers
def configure_optimizers(self):
"""Optimizer configuration for the solver.
"""
Optimizer configuration for the solver.
:return: The optimizers and the schedulers
:return: The optimizer and the scheduler
:rtype: tuple(list, list)
"""
for optimizer, scheduler, model in zip(
@@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
)
def _compile_model(self):
"""
Compile the models.
"""
for i, model in enumerate(self._pina_models):
if not isinstance(model, torch.nn.ModuleDict):
self._pina_models[i] = self._perform_compilation(model)
@@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
@property
def models(self):
"""
The torch model."""
The models used for training.
:return: The models used for training.
:rtype: torch.nn.ModuleList
"""
return self._pina_models
@property
def optimizers(self):
"""
The torch model."""
The optimizers used for training.
:return: The optimizers used for training.
:rtype: list[Optimizer]
"""
return self._pina_optimizers
@property
def schedulers(self):
"""
The torch model."""
The schedulers used for training.
:return: The schedulers used for training.
:rtype: list[Scheduler]
"""
return self._pina_schedulers
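
To tie the multi-model pieces together, a hypothetical two-network solver; the problem object, point keys, and reliance on default optimizers are assumptions for illustration only:

    import torch

    class TwoNetSolver(MultiSolverInterface):
        def forward(self, x):
            # Purely illustrative combination of the two networks.
            return self.models[0](x) + self.models[1](x)

        def optimization_cycle(self, batch):
            return {
                name: torch.nn.functional.mse_loss(
                    self.forward(points["input"]), points["target"]
                )
                for name, points in batch
            }

    solver = TwoNetSolver(
        problem=my_problem,  # assumed AbstractProblem instance
        models=[torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)],
    )
    # One optimizer and one scheduler per model, by construction.
    assert len(solver.optimizers) == len(solver.models)
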