CPU/GPU/TPU training (#159)
* device training --------- Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
committed by
Nicola Demo
parent
38ecebd44b
commit
92e0e4920b
@@ -10,6 +10,9 @@ class Trainer(pl.Trainer):
|
||||
def __init__(self, solver, kwargs={}):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# get accellerator
|
||||
device = self._accelerator_connector._accelerator_flag
|
||||
|
||||
# check inheritance consistency for solver
|
||||
check_consistency(solver, SolverInterface)
|
||||
self._model = solver
|
||||
@@ -23,7 +26,7 @@ class Trainer(pl.Trainer):
|
||||
'in the provided locations.')
|
||||
|
||||
# TODO: make a better dataloader for train
|
||||
self._loader = DummyLoader(solver.problem.input_pts)
|
||||
self._loader = DummyLoader(solver.problem.input_pts, device)
|
||||
|
||||
|
||||
def train(self): # TODO add kwargs and lightining capabilities
|
||||
|
||||
Reference in New Issue
Block a user