CPU/GPU/TPU training (#159)

* device training

---------

Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Dario Coscia
2023-07-19 17:19:08 +02:00
committed by Nicola Demo
parent 38ecebd44b
commit 92e0e4920b
4 changed files with 62 additions and 28 deletions

View File

@@ -10,6 +10,9 @@ class Trainer(pl.Trainer):
def __init__(self, solver, kwargs={}):
super().__init__(**kwargs)
# get accellerator
device = self._accelerator_connector._accelerator_flag
# check inheritance consistency for solver
check_consistency(solver, SolverInterface)
self._model = solver
@@ -23,7 +26,7 @@ class Trainer(pl.Trainer):
'in the provided locations.')
# TODO: make a better dataloader for train
self._loader = DummyLoader(solver.problem.input_pts)
self._loader = DummyLoader(solver.problem.input_pts, device)
def train(self): # TODO add kwargs and lightining capabilities