Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -1,17 +1,35 @@
|
||||
""" Solver module. """
|
||||
""" Trainer module. """
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
import pytorch_lightning
|
||||
from .utils import check_consistency
|
||||
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(Trainer):
|
||||
class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
def __init__(self, solver, batch_size=None, **kwargs):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
|
||||
:param solver: A pina:class:`SolverInterface` solver for the differential problem.
|
||||
:type solver: SolverInterface
|
||||
:param batch_size: How many samples per batch to load. If ``batch_size=None`` all
|
||||
samples are loaded and data are not batched, defaults to None.
|
||||
:type batch_size: int | None
|
||||
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments specify the training setup
|
||||
and can be choosen from the `pytorch-lightning
|
||||
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
|
||||
"""
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# check inheritance consistency for solver
|
||||
# check inheritance consistency for solver and batch size
|
||||
check_consistency(solver, SolverInterface)
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
|
||||
self._model = solver
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -45,7 +63,7 @@ class Trainer(Trainer):
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Train the solver.
|
||||
Train the solver method.
|
||||
"""
|
||||
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
|
||||
|
||||
@@ -54,4 +72,4 @@ class Trainer(Trainer):
|
||||
"""
|
||||
Returning trainer solver.
|
||||
"""
|
||||
return self._model
|
||||
return self._model
|
||||
|
||||
Reference in New Issue
Block a user