fix rendering part 2

This commit is contained in:
giovanni
2025-03-14 00:10:18 +01:00
committed by FilippoOlivo
parent 3d842cb9ec
commit 76f5be85ea
17 changed files with 217 additions and 147 deletions

View File

@@ -52,13 +52,14 @@ class SupervisedSolver(SingleSolverInterface):
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param torch.optim.Optimizer optimizer: The optimizer to be used.
If `None`, the Adam optimizer is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
If `None`, the constant learning rate scheduler is used.
:param Optimizer optimizer: The optimizer to be used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
@@ -86,10 +87,11 @@ class SupervisedSolver(SingleSolverInterface):
"""
The optimization cycle for the solvers.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The computed loss for the all conditions in the batch, casted
to a subclass of `torch.Tensor`. It should return a dict containing
the condition name and the associated scalar loss.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}