fix rendering part 2

giovanni, 2025-03-14 00:10:18 +01:00 (committed by Nicola Demo)
parent e0ad4dc8a0
commit d2e3f458ab
17 changed files with 217 additions and 147 deletions

@@ -75,9 +75,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
 
         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
-        :return: The computed loss for the all conditions in the batch, casted
-            to a subclass of `torch.Tensor`. It should return a dict containing
-            the condition name and the associated scalar loss.
+        :return: The losses computed for all conditions in the batch, cast
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
         :rtype: dict
         """
         losses = self.optimization_cycle(batch)
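
For context, the corrected `:return:` entry describes a mapping from condition name to scalar loss. A minimal sketch of how such a dict could be reduced to the single scalar used downstream; the mean reduction here is an assumption for illustration, not PINA's actual weighting logic:

    import torch

    def aggregate(losses: dict) -> torch.Tensor:
        # losses maps condition name -> scalar loss, e.g.
        # {"physics": tensor(0.12), "boundary": tensor(0.03)}
        return sum(losses.values()) / len(losses)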
@@ -92,7 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver training step.
 
-        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
         :return: The loss of the training step.
         :rtype: LabelTensor
         """
@@ -104,7 +105,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver validation step.
 
-        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
         """
         loss = self._optimization_cycle(batch=batch)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -113,7 +115,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver test step.
 
-        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
         """
         loss = self._optimization_cycle(batch=batch)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
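
All three step docstrings now document the same batch shape. A hypothetical example of such a batch; the condition names and dictionary keys below are invented for illustration:

    import torch

    # A batch is a list of (condition name, points dict) pairs.
    batch = [
        ("physics", {"input_points": torch.rand(16, 2)}),
        ("data", {"input_points": torch.rand(8, 2),
                  "target_points": torch.rand(8, 1)}),
    ]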
@@ -138,6 +141,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     def forward(self, *args, **kwargs):
         """
         Abstract method for the forward pass implementation.
+
+        :param args: The input tensor.
+        :type args: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments.
         """
 
     @abstractmethod
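
The new parameter entries say the inputs may be plain tensors or LabelTensors. A minimal sketch of a concrete override, assuming the subclass wraps a single `self.model` torch module (the attribute name is taken from the SingleSolverInterface hunks below):

    def forward(self, x):
        # Delegate to the wrapped module; x may be a torch.Tensor or a
        # LabelTensor, both of which an nn.Module accepts.
        return self.model(x)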
@@ -145,10 +152,11 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     def optimization_cycle(self, batch):
         """
         The optimization cycle for the solvers.
 
-        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
-        :return: The computed loss for the all conditions in the batch, casted
-            to a subclass of `torch.Tensor`. It should return a dict containing
-            the condition name and the associated scalar loss.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :return: The losses computed for all conditions in the batch, cast
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
         :rtype: dict
         """
@@ -187,7 +195,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Get the batch size.
 
-        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
         :return: The size of the batch.
         :rtype: int
         """
@@ -296,10 +305,11 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         :param AbstractProblem problem: The problem to be solved.
         :param torch.nn.Module model: The neural network model to be used.
         :param Optimizer optimizer: The optimizer to be used.
-            If `None`, the Adam optimizer is used. Default is ``None``.
+            If `None`, the :class:`torch.optim.Adam` optimizer is
+            used. Default is ``None``.
         :param Scheduler scheduler: The scheduler to be used.
-            If `None`, the constant learning rate scheduler is used.
-            Default is ``None``.
+            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            scheduler is used. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
             If `None`, no weighting schema is used. Default is ``None``.
         :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
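
In plain torch, the documented defaults resolve to the following; this shows the equivalence the corrected docstring states, not PINA's internal construction code:

    import torch

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.Adam(model.parameters())            # optimizer=None
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)  # scheduler=None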
@@ -341,7 +351,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         Optimizer configuration for the solver.
 
         :return: The optimizer and the scheduler
-        :rtype: tuple(list, list)
+        :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
         self.optimizer.hook(self.model.parameters())
         self.scheduler.hook(self.optimizer)
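
The two hook calls are visible in the hunk; a sketch of the full method consistent with the corrected `:rtype:`, where the return statement is an assumption not shown in this diff:

    def configure_optimizers(self):
        self.optimizer.hook(self.model.parameters())
        self.scheduler.hook(self.optimizer)
        # Assumed return shape, matching tuple[list[Optimizer], list[Scheduler]].
        return [self.optimizer], [self.scheduler]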
@@ -421,11 +431,11 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         :param models: The neural network models to be used.
         :type model: list[torch.nn.Module] | tuple[torch.nn.Module]
         :param list[Optimizer] optimizers: The optimizers to be used.
-            If `None`, the Adam optimizer is used for all models.
-            Default is ``None``.
-        :param list[Scheduler] schedulers: The schedulers to be used.
-            If `None`, the constant learning rate scheduler is used for all the
-            models. Default is ``None``.
+            If `None`, the :class:`torch.optim.Adam` optimizer is used for all
+            models. Default is ``None``.
+        :param list[Scheduler] schedulers: The schedulers to be used.
+            If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            scheduler is used for all the models. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
             If `None`, no weighting schema is used. Default is ``None``.
         :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
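
The plain-torch counterpart of the multi-model defaults, one Adam optimizer and one ConstantLR scheduler per model (illustrative, not PINA code):

    import torch

    models = [torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)]
    optimizers = [torch.optim.Adam(m.parameters()) for m in models]
    schedulers = [torch.optim.lr_scheduler.ConstantLR(o) for o in optimizers]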
@@ -480,7 +490,7 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         Optimizer configuration for the solver.
 
         :return: The optimizer and the scheduler
-        :rtype: tuple(list, list)
+        :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
         for optimizer, scheduler, model in zip(
             self.optimizers, self.schedulers, self.models
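
The zip over optimizers, schedulers, and models is the visible context; a sketch of the rest of the loop, assuming it mirrors the single-solver hook pattern (the loop body and return value are not shown in this diff):

    def configure_optimizers(self):
        for optimizer, scheduler, model in zip(
            self.optimizers, self.schedulers, self.models
        ):
            optimizer.hook(model.parameters())  # assumed, mirrors single solver
            scheduler.hook(optimizer)           # assumed, mirrors single solver
        return list(self.optimizers), list(self.schedulers)  # assumed shape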