Trainer train simplified, tests for load (#168)

- the arguments of Trainer.train now are passed to the fit
- unittest for load/restoring from checkpoint
This commit is contained in:
Nicola Demo
2023-07-25 17:23:12 +02:00
parent de0c3fca82
commit e84def3bf9
3 changed files with 44 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ from .solvers.solver import SolverInterface
class Trainer(pl.Trainer):
def __init__(self, solver, kwargs={}):
def __init__(self, solver, **kwargs):
super().__init__(**kwargs)
# get accellerator
@@ -29,6 +29,6 @@ class Trainer(pl.Trainer):
self._loader = DummyLoader(solver.problem.input_pts, device)
def train(self): # TODO add kwargs and lightining capabilities
return super().fit(self._model, self._loader)
def train(self, **kwargs): # TODO add kwargs and lightining capabilities
return super().fit(self._model, self._loader, **kwargs)