CPU/GPU/TPU training (#159)
* device training --------- Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
committed by
Nicola Demo
parent
38ecebd44b
commit
92e0e4920b
@@ -1,5 +1,6 @@
|
||||
""" """
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import functools
|
||||
|
||||
|
||||
class PinaDataset():
|
||||
@@ -48,7 +49,30 @@ class PinaDataset():
|
||||
# TODO: working also for datapoints
|
||||
class DummyLoader:
|
||||
|
||||
def __init__(self, data) -> None:
|
||||
def __init__(self, data, device) -> None:
|
||||
|
||||
# TODO: We need to make a dataset somehow
|
||||
# and the PINADataset needs to have a method
|
||||
# to send points to device
|
||||
# now we simply do it here
|
||||
# send data to device
|
||||
def convert_tensors(pts, device):
|
||||
pts = pts.to(device)
|
||||
pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
return pts
|
||||
|
||||
for location, pts in data.items():
|
||||
if isinstance(pts, (tuple, list)):
|
||||
pts = tuple(map(functools.partial(convert_tensors, device=device),pts))
|
||||
else:
|
||||
pts = pts.to(device)
|
||||
pts = pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
|
||||
data[location] = pts
|
||||
|
||||
# iterator
|
||||
self.data = [data]
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
Reference in New Issue
Block a user