fix doc solver
This commit is contained in:
@@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
SolverInterface base class. This class is a wrapper of LightningModule.
|
||||
Abstract base class for PINA solvers. All specific solvers should inherit
|
||||
from this interface. This class is a wrapper of
|
||||
:class:`~lightning.pytorch.LightningModule`.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, weighting, use_lt):
|
||||
"""
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param weighting: The loss weighting to use.
|
||||
:type weighting: WeightingInterface
|
||||
:param use_lt: Using LabelTensors as input during training.
|
||||
:type use_lt: bool
|
||||
Initialization of the :class:`SolverInterface` class.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
self._pina_schedulers = None
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
"""
|
||||
Check the consistency of the solver with the problem formulation.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
"""
|
||||
for condition in problem.conditions.values():
|
||||
check_consistency(condition, self.accepted_conditions_types)
|
||||
|
||||
def _optimization_cycle(self, batch):
|
||||
"""
|
||||
Perform a private optimization cycle by computing the loss for each
|
||||
condition in the given batch. The loss are later aggregated using the
|
||||
specific weighting schema.
|
||||
Aggregate the loss for each condition in the batch.
|
||||
|
||||
:param batch: A batch of data, where each element is a tuple containing
|
||||
a condition name and a dictionary of points.
|
||||
:type batch: list of tuples (str, dict)
|
||||
:return: The computed loss for the all conditions in the batch,
|
||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict(torch.Tensor)
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The computed loss for the all conditions in the batch, casted
|
||||
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||
the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
losses = self.optimization_cycle(batch)
|
||||
for name, value in losses.items():
|
||||
@@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver training step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:return: The sum of the loss functions.
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:return: The loss of the training step.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
loss = self._optimization_cycle(batch=batch)
|
||||
@@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver validation step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
"""
|
||||
loss = self._optimization_cycle(batch=batch)
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
@@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver test step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
"""
|
||||
loss = self._optimization_cycle(batch=batch)
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
|
||||
def store_log(self, name, value, batch_size):
|
||||
"""
|
||||
TODO
|
||||
Store the log of the solver.
|
||||
|
||||
:param str name: The name of the log.
|
||||
:param torch.Tensor value: The value of the log.
|
||||
:param int batch_size: The size of the batch.
|
||||
"""
|
||||
|
||||
self.log(
|
||||
@@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
Abstract method for the forward pass implementation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def optimization_cycle(self, batch):
|
||||
"""
|
||||
Perform an optimization cycle by computing the loss for each condition
|
||||
in the given batch.
|
||||
The optimization cycle for the solvers.
|
||||
|
||||
:param batch: A batch of data, where each element is a tuple containing
|
||||
a condition name and a dictionary of points.
|
||||
:type batch: list of tuples (str, dict)
|
||||
:return: The computed loss for the all conditions in the batch,
|
||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict(torch.Tensor)
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:return: The computed loss for the all conditions in the batch, casted
|
||||
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||
the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
The problem formulation.
|
||||
The problem instance.
|
||||
|
||||
:return: The problem instance.
|
||||
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
|
||||
"""
|
||||
return self._pina_problem
|
||||
|
||||
@property
|
||||
def use_lt(self):
|
||||
"""
|
||||
Using LabelTensor in training.
|
||||
Using LabelTensors as input during training.
|
||||
|
||||
:return: The use_lt attribute.
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._use_lt
|
||||
|
||||
@property
|
||||
def weighting(self):
|
||||
"""
|
||||
The weighting mechanism.
|
||||
The weighting schema.
|
||||
|
||||
:return: The weighting schema.
|
||||
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
|
||||
"""
|
||||
return self._pina_weighting
|
||||
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
"""
|
||||
TODO
|
||||
Get the batch size.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:return: The size of the batch.
|
||||
:rtype: int
|
||||
"""
|
||||
|
||||
batch_size = 0
|
||||
@@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
def default_torch_optimizer():
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
Set the default optimizer to :class:`torch.optim.Adam`.
|
||||
|
||||
:return: The default optimizer.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return TorchOptimizer(torch.optim.Adam, lr=0.001)
|
||||
|
||||
@staticmethod
|
||||
def default_torch_scheduler():
|
||||
"""
|
||||
TODO
|
||||
Set the default scheduler to
|
||||
:class:`torch.optim.lr_scheduler.ConstantLR`.
|
||||
|
||||
:return: The default scheduler.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
|
||||
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
Hook that is called before training begins.
|
||||
Used to compile the model if the trainer is set to compile.
|
||||
This method is called at the start of the training process to compile
|
||||
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
|
||||
"""
|
||||
super().on_train_start()
|
||||
if self.trainer.compile:
|
||||
@@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
def on_test_start(self):
|
||||
"""
|
||||
Hook that is called before training begins.
|
||||
Used to compile the model if the trainer is set to compile.
|
||||
This method is called at the start of the test process to compile
|
||||
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
|
||||
"""
|
||||
super().on_train_start()
|
||||
if self.trainer.compile and not self._check_already_compiled():
|
||||
@@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
def _check_already_compiled(self):
|
||||
"""
|
||||
TODO
|
||||
Check if the model is already compiled.
|
||||
|
||||
:return: ``True`` if the model is already compiled, ``False`` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
models = self._pina_models
|
||||
@@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
def _perform_compilation(model):
|
||||
"""
|
||||
TODO
|
||||
Perform the compilation of the model.
|
||||
|
||||
:param torch.nn.Module model: The model to compile.
|
||||
:raises Exception: If the compilation fails.
|
||||
:return: The compiled model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
|
||||
model_device = next(model.parameters()).device
|
||||
@@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
|
||||
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""TODO"""
|
||||
|
||||
"""
|
||||
Base class for PINA solvers using a single :class:`torch.nn.Module`.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
@@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
use_lt=True,
|
||||
):
|
||||
"""
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param model: A torch nn.Module instances.
|
||||
:type model: torch.nn.Module
|
||||
:param Optimizer optimizers: A neural network optimizers to use.
|
||||
:param Scheduler optimizers: A neural network scheduler to use.
|
||||
:param WeightingInterface weighting: The loss weighting to use.
|
||||
:param bool use_lt: Using LabelTensors as input during training.
|
||||
Initialization of the :class:`SingleSolverInterface` class.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||
:param Scheduler scheduler: The scheduler to be used.
|
||||
If `None`, the constant learning rate scheduler is used.
|
||||
Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
"""
|
||||
if optimizer is None:
|
||||
optimizer = self.default_torch_optimizer()
|
||||
@@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the solver.
|
||||
Forward pass implementation.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:param x: Input tensor.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
:rtype: torch.Tensor | LabelTensor
|
||||
"""
|
||||
x = self.model(x)
|
||||
return x
|
||||
@@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Optimizer configuration for the solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:return: The optimizer and the scheduler
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
self.optimizer.hook(self.model.parameters())
|
||||
@@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||
|
||||
def _compile_model(self):
|
||||
"""
|
||||
Compile the model.
|
||||
"""
|
||||
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
|
||||
self._compile_module_dict()
|
||||
else:
|
||||
self._compile_single_model()
|
||||
|
||||
def _compile_module_dict(self):
|
||||
"""
|
||||
Compile the model if it is a :class:`torch.nn.ModuleDict`.
|
||||
"""
|
||||
for name, model in self._pina_models[0].items():
|
||||
self._pina_models[0][name] = self._perform_compilation(model)
|
||||
|
||||
def _compile_single_model(self):
|
||||
"""
|
||||
Compile the model if it is a single :class:`torch.nn.Module`.
|
||||
"""
|
||||
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
Model for training.
|
||||
The model used for training.
|
||||
|
||||
:return: The model used for training.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._pina_models[0]
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for training.
|
||||
The scheduler used for training.
|
||||
|
||||
:return: The scheduler used for training.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
return self._pina_schedulers[0]
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""
|
||||
Optimizer for training.
|
||||
The optimizer used for training.
|
||||
|
||||
:return: The optimizer used for training.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return self._pina_optimizers[0]
|
||||
|
||||
|
||||
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Multiple Solver base class. This class inherits is a wrapper of
|
||||
SolverInterface class
|
||||
Base class for PINA solvers using multiple :class:`torch.nn.Module`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
use_lt=True,
|
||||
):
|
||||
"""
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param models: Multiple torch nn.Module instances.
|
||||
Initialization of the :class:`MultiSolverInterface` class.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param models: The neural network models to be used.
|
||||
:type model: list[torch.nn.Module] | tuple[torch.nn.Module]
|
||||
:param list(Optimizer) optimizers: A list of neural network
|
||||
optimizers to use.
|
||||
:param list(Scheduler) optimizers: A list of neural network
|
||||
schedulers to use.
|
||||
:param WeightingInterface weighting: The loss weighting to use.
|
||||
:param bool use_lt: Using LabelTensors as input during training.
|
||||
:param list[Optimizer] optimizers: The optimizers to be used.
|
||||
If `None`, the Adam optimizer is used for all models.
|
||||
Default is ``None``.
|
||||
:param list[Scheduler] schedulers: The schedulers to be used.
|
||||
If `None`, the constant learning rate scheduler is used for all the
|
||||
models. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
:raises ValueError: If the models are not a list or tuple with length
|
||||
greater than one.
|
||||
"""
|
||||
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
||||
raise ValueError(
|
||||
@@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self._pina_schedulers = schedulers
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the solver.
|
||||
"""
|
||||
Optimizer configuration for the solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:return: The optimizer and the scheduler
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
for optimizer, scheduler, model in zip(
|
||||
@@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
def _compile_model(self):
|
||||
"""
|
||||
Compile the model.
|
||||
"""
|
||||
for i, model in enumerate(self._pina_models):
|
||||
if not isinstance(model, torch.nn.ModuleDict):
|
||||
self._pina_models[i] = self._perform_compilation(model)
|
||||
@@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
@property
|
||||
def models(self):
|
||||
"""
|
||||
The torch model."""
|
||||
The models used for training.
|
||||
|
||||
:return: The models used for training.
|
||||
:rtype: torch.nn.ModuleList
|
||||
"""
|
||||
return self._pina_models
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""
|
||||
The torch model."""
|
||||
The optimizers used for training.
|
||||
|
||||
:return: The optimizers used for training.
|
||||
:rtype: list[Optimizer]
|
||||
"""
|
||||
return self._pina_optimizers
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
"""
|
||||
The torch model."""
|
||||
The schedulers used for training.
|
||||
|
||||
:return: The schedulers used for training.
|
||||
:rtype: list[Scheduler]
|
||||
"""
|
||||
return self._pina_schedulers
|
||||
|
||||
Reference in New Issue
Block a user