fix rendering part 1
This commit is contained in:
@@ -34,8 +34,9 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
Initialization of the :class:`Trainer` class.
|
||||
|
||||
:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
|
||||
solver used to solve a :class:`~pina.problem.AbstractProblem`.
|
||||
:param SolverInterface solver: A
|
||||
:class:`~pina.solver.solver.SolverInterface` solver used to solve a
|
||||
:class:`~pina.problem.abstract_problem.AbstractProblem`.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
If ``None``, all samples are loaded and data is not batched.
|
||||
Default is ``None``.
|
||||
@@ -56,11 +57,10 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
transfer to GPU. Default is ``False``.
|
||||
:param bool shuffle: Whether to shuffle the data during training.
|
||||
Default is ``True``.
|
||||
|
||||
:Keyword Arguments:
|
||||
Additional keyword arguments that specify the training setup.
|
||||
These can be selected from the pytorch-lightning Trainer API
|
||||
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>_.
|
||||
:param dict kwargs: Additional keyword arguments that specify the
|
||||
training setup. These can be selected from the `pytorch-lightning
|
||||
Trainer API
|
||||
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
|
||||
"""
|
||||
# check consistency for init types
|
||||
self._check_input_consistency(
|
||||
@@ -132,7 +132,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
def _move_to_device(self):
|
||||
"""
|
||||
Moves the ``unknown_parameters`` of an instance of
|
||||
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
|
||||
:class:`~pina.problem.abstract_problem.AbstractProblem` to the
|
||||
:class:`Trainer` device.
|
||||
"""
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
# move parameters to device
|
||||
@@ -205,12 +206,16 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Manage the training process of the solver.
|
||||
|
||||
:param dict kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
|
||||
|
||||
def test(self, **kwargs):
|
||||
"""
|
||||
Manage the test process of the solver.
|
||||
|
||||
:param dict kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super().test(self.solver, datamodule=self.data_module, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user