fix rendering part 2
This commit is contained in:
@@ -75,9 +75,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The computed loss for the all conditions in the batch, casted
|
||||
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||
the condition name and the associated scalar loss.
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
losses = self.optimization_cycle(batch)
|
||||
@@ -92,7 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver training step.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the training step.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -104,7 +105,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver validation step.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
"""
|
||||
loss = self._optimization_cycle(batch=batch)
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
@@ -113,7 +115,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver test step.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
"""
|
||||
loss = self._optimization_cycle(batch=batch)
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
@@ -138,6 +141,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Abstract method for the forward pass implementation.
|
||||
|
||||
:param args: The input tensor.
|
||||
:type args: torch.Tensor | LabelTensor
|
||||
:param dict kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -145,10 +152,11 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
The optimization cycle for the solvers.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:return: The computed loss for the all conditions in the batch, casted
|
||||
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||
the condition name and the associated scalar loss.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
@@ -187,7 +195,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Get the batch size.
|
||||
|
||||
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The size of the batch.
|
||||
:rtype: int
|
||||
"""
|
||||
@@ -296,10 +305,11 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is
|
||||
used. Default is ``None``.
|
||||
:param Scheduler scheduler: The scheduler to be used.
|
||||
If `None`, the constant learning rate scheduler is used.
|
||||
Default is ``None``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
@@ -341,7 +351,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
Optimizer configuration for the solver.
|
||||
|
||||
:return: The optimizer and the scheduler
|
||||
:rtype: tuple(list, list)
|
||||
:rtype: tuple[list[Optimizer], list[Scheduler]]
|
||||
"""
|
||||
self.optimizer.hook(self.model.parameters())
|
||||
self.scheduler.hook(self.optimizer)
|
||||
@@ -421,11 +431,11 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:param models: The neural network models to be used.
|
||||
:type model: list[torch.nn.Module] | tuple[torch.nn.Module]
|
||||
:param list[Optimizer] optimizers: The optimizers to be used.
|
||||
If `None`, the Adam optimizer is used for all models.
|
||||
Default is ``None``.
|
||||
:param list[Scheduler] schedulers: The schedulers to be used.
|
||||
If `None`, the constant learning rate scheduler is used for all the
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used for all
|
||||
models. Default is ``None``.
|
||||
:param list[Scheduler] schedulers: The schedulers to be used.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used for all the models. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||
@@ -480,7 +490,7 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
Optimizer configuration for the solver.
|
||||
|
||||
:return: The optimizer and the scheduler
|
||||
:rtype: tuple(list, list)
|
||||
:rtype: tuple[list[Optimizer], list[Scheduler]]
|
||||
"""
|
||||
for optimizer, scheduler, model in zip(
|
||||
self.optimizers, self.schedulers, self.models
|
||||
|
||||
Reference in New Issue
Block a user