fix rendering part 1

This commit is contained in:
giovanni
2025-03-13 22:31:26 +01:00
committed by Nicola Demo
parent 5d908a291d
commit e0ad4dc8a0
15 changed files with 89 additions and 63 deletions

View File

@@ -34,8 +34,9 @@ class Trainer(lightning.pytorch.Trainer):
"""
Initialization of the :class:`Trainer` class.
:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
solver used to solve a :class:`~pina.problem.AbstractProblem`.
:param SolverInterface solver: A
:class:`~pina.solver.solver.SolverInterface` solver used to solve a
:class:`~pina.problem.abstract_problem.AbstractProblem`.
:param int batch_size: The number of samples per batch to load.
If ``None``, all samples are loaded and data is not batched.
Default is ``None``.
@@ -56,11 +57,10 @@ class Trainer(lightning.pytorch.Trainer):
transfer to GPU. Default is ``False``.
:param bool shuffle: Whether to shuffle the data during training.
Default is ``True``.
:Keyword Arguments:
Additional keyword arguments that specify the training setup.
These can be selected from the pytorch-lightning Trainer API
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>_.
:param dict kwargs: Additional keyword arguments that specify the
training setup. These can be selected from the `pytorch-lightning
Trainer API
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
"""
# check consistency for init types
self._check_input_consistency(
@@ -132,7 +132,8 @@ class Trainer(lightning.pytorch.Trainer):
def _move_to_device(self):
"""
Moves the ``unknown_parameters`` of an instance of
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
:class:`~pina.problem.abstract_problem.AbstractProblem` to the
:class:`Trainer` device.
"""
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
@@ -205,12 +206,16 @@ class Trainer(lightning.pytorch.Trainer):
def train(self, **kwargs):
"""
Manage the training process of the solver.
:param dict kwargs: Additional keyword arguments.
"""
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
def test(self, **kwargs):
"""
Manage the test process of the solver.
:param dict kwargs: Additional keyword arguments.
"""
return super().test(self.solver, datamodule=self.data_module, **kwargs)