CPU/GPU/TPU training (#159)
* device training --------- Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
committed by
Nicola Demo
parent
38ecebd44b
commit
92e0e4920b
@@ -1,5 +1,6 @@
|
||||
""" """
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import functools
|
||||
|
||||
|
||||
class PinaDataset():
|
||||
@@ -48,7 +49,30 @@ class PinaDataset():
|
||||
# TODO: working also for datapoints
|
||||
class DummyLoader:
|
||||
|
||||
def __init__(self, data) -> None:
|
||||
def __init__(self, data, device) -> None:
|
||||
|
||||
# TODO: We need to make a dataset somehow
|
||||
# and the PINADataset needs to have a method
|
||||
# to send points to device
|
||||
# now we simply do it here
|
||||
# send data to device
|
||||
def convert_tensors(pts, device):
|
||||
pts = pts.to(device)
|
||||
pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
return pts
|
||||
|
||||
for location, pts in data.items():
|
||||
if isinstance(pts, (tuple, list)):
|
||||
pts = tuple(map(functools.partial(convert_tensors, device=device),pts))
|
||||
else:
|
||||
pts = pts.to(device)
|
||||
pts = pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
|
||||
data[location] = pts
|
||||
|
||||
# iterator
|
||||
self.data = [data]
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
@@ -111,7 +111,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
continue
|
||||
self.input_pts[condition_name] = samples
|
||||
|
||||
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all', device=None):
|
||||
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
@@ -129,9 +129,9 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
:type locations: str, optional
|
||||
|
||||
:Example:
|
||||
>>> pinn.span_pts(n=10, mode='grid')
|
||||
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||
>>> pinn.discretise_domain(n=10, mode='grid')
|
||||
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
|
||||
|
||||
.. warning::
|
||||
``random`` is currently the only implemented ``mode`` for all geometries, i.e.
|
||||
@@ -200,12 +200,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
|
||||
self._have_sampled_points[location] = True
|
||||
# setting device
|
||||
if device:
|
||||
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
|
||||
# setting the grad
|
||||
self.input_pts[location].requires_grad_(True)
|
||||
self.input_pts[location].retain_grad()
|
||||
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
|
||||
@@ -10,6 +10,9 @@ class Trainer(pl.Trainer):
|
||||
def __init__(self, solver, kwargs={}):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# get accellerator
|
||||
device = self._accelerator_connector._accelerator_flag
|
||||
|
||||
# check inheritance consistency for solver
|
||||
check_consistency(solver, SolverInterface)
|
||||
self._model = solver
|
||||
@@ -23,7 +26,7 @@ class Trainer(pl.Trainer):
|
||||
'in the provided locations.')
|
||||
|
||||
# TODO: make a better dataloader for train
|
||||
self._loader = DummyLoader(solver.problem.input_pts)
|
||||
self._loader = DummyLoader(solver.problem.input_pts, device)
|
||||
|
||||
|
||||
def train(self): # TODO add kwargs and lightining capabilities
|
||||
|
||||
Reference in New Issue
Block a user