CPU/GPU/TPU training (#159)

* device training

---------

Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Dario Coscia
2023-07-19 17:19:08 +02:00
committed by Nicola Demo
parent 38ecebd44b
commit 92e0e4920b
4 changed files with 62 additions and 28 deletions

View File

@@ -1,5 +1,6 @@
""" """
from torch.utils.data import Dataset, DataLoader
import functools
class PinaDataset():
@@ -48,7 +49,30 @@ class PinaDataset():
# TODO: working also for datapoints
class DummyLoader:
def __init__(self, data) -> None:
def __init__(self, data, device) -> None:
# TODO: We need to make a dataset somehow
# and the PINADataset needs to have a method
# to send points to device
# now we simply do it here
# send data to device
def convert_tensors(pts, device):
pts = pts.to(device)
pts.requires_grad_(True)
pts.retain_grad()
return pts
for location, pts in data.items():
if isinstance(pts, (tuple, list)):
pts = tuple(map(functools.partial(convert_tensors, device=device),pts))
else:
pts = pts.to(device)
pts = pts.requires_grad_(True)
pts.retain_grad()
data[location] = pts
# iterator
self.data = [data]
def __iter__(self):