""" Solver module. """ import lightning.pytorch as pl from .utils import check_consistency from .dataset import DummyLoader from .solver import SolverInterface class Trainer(pl.Trainer): def __init__(self, solver, kwargs={}): super().__init__(**kwargs) # check inheritance consistency for solver check_consistency(solver, SolverInterface) self._model = solver # create dataloader if solver.problem.have_sampled_points is False: raise RuntimeError(f'Input points in {solver.problem.not_sampled_points} ' 'training are None. Please ' 'sample points in your problem by calling ' 'discretise_domain function before train ' 'in the provided locations.') # TODO: make a better dataloader for train self._loader = DummyLoader(solver.problem.input_pts) def train(self): # TODO add kwargs and lightining capabilities return super().fit(self._model, self._loader)