Documentation for v0.1 version (#199)

* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:39:00 +01:00
committed by Nicola Demo
parent 3f9305d475
commit 8b7b61b3bd
144 changed files with 2741 additions and 1766 deletions

View File

@@ -1,17 +1,35 @@
""" Solver module. """
""" Trainer module. """
from pytorch_lightning import Trainer
import pytorch_lightning
from .utils import check_consistency
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
class Trainer(Trainer):
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, **kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
:param solver: A pina:class:`SolverInterface` solver for the differential problem.
:type solver: SolverInterface
:param batch_size: How many samples per batch to load. If ``batch_size=None`` all
samples are loaded and data are not batched, defaults to None.
:type batch_size: int | None
:Keyword Arguments:
The additional keyword arguments specify the training setup
and can be choosen from the `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
"""
super().__init__(**kwargs)
# check inheritance consistency for solver
# check inheritance consistency for solver and batch size
check_consistency(solver, SolverInterface)
if batch_size is not None:
check_consistency(batch_size, int)
self._model = solver
self.batch_size = batch_size
@@ -45,7 +63,7 @@ class Trainer(Trainer):
def train(self, **kwargs):
"""
Train the solver.
Train the solver method.
"""
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
@@ -54,4 +72,4 @@ class Trainer(Trainer):
"""
Returning trainer solver.
"""
return self._model
return self._model