Trainer train simplified, tests for load (#168)
- the arguments of Trainer.train now are passed to the fit - unittest for load/restoring from checkpoint
This commit is contained in:
@@ -7,7 +7,7 @@ from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(pl.Trainer):
|
||||
|
||||
def __init__(self, solver, kwargs={}):
|
||||
def __init__(self, solver, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# get accellerator
|
||||
@@ -29,6 +29,6 @@ class Trainer(pl.Trainer):
|
||||
self._loader = DummyLoader(solver.problem.input_pts, device)
|
||||
|
||||
|
||||
def train(self): # TODO add kwargs and lightining capabilities
|
||||
return super().fit(self._model, self._loader)
|
||||
def train(self, **kwargs): # TODO add kwargs and lightining capabilities
|
||||
return super().fit(self._model, self._loader, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user