Update dataset.py
This commit is contained in:
@@ -1,3 +1,7 @@
|
|||||||
|
""" """
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
|
||||||
class PinaDataset():
|
class PinaDataset():
|
||||||
|
|
||||||
def __init__(self, pinn) -> None:
|
def __init__(self, pinn) -> None:
|
||||||
@@ -39,83 +43,7 @@ class PinaDataset():
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._len
|
return self._len
|
||||||
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
class LabelTensorDataset(Dataset):
|
|
||||||
def __init__(self, d):
|
|
||||||
for k, v in d.items():
|
|
||||||
setattr(self, k, v)
|
|
||||||
self.labels = list(d.keys())
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
print(index)
|
|
||||||
result = {}
|
|
||||||
for label in self.labels:
|
|
||||||
sample_tensor = getattr(self, label)
|
|
||||||
|
|
||||||
# print('porcodio')
|
|
||||||
# print(sample_tensor.shape[1])
|
|
||||||
# print(index)
|
|
||||||
# print(sample_tensor[index])
|
|
||||||
try:
|
|
||||||
result[label] = sample_tensor[index]
|
|
||||||
except IndexError:
|
|
||||||
result[label] = torch.tensor([])
|
|
||||||
|
|
||||||
print(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return max([len(getattr(self, label)) for label in self.labels])
|
|
||||||
|
|
||||||
|
|
||||||
class LabelTensorDataLoader(DataLoader):
|
|
||||||
|
|
||||||
def collate_fn(self, data):
|
|
||||||
print(data)
|
|
||||||
gggggggggg
|
|
||||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
|
||||||
|
|
||||||
# class SampleDataset(torch.utils.data.Dataset):
|
|
||||||
|
|
||||||
# def __init__(self, location, tensor):
|
|
||||||
# self._tensor = tensor
|
|
||||||
# self._location = location
|
|
||||||
# self._len = len(tensor)
|
|
||||||
|
|
||||||
# def __getitem__(self, index):
|
|
||||||
# tensor = self._tensor.select(0, index)
|
|
||||||
# return {self._location: tensor}
|
|
||||||
|
|
||||||
# def __len__(self):
|
|
||||||
# return self._len
|
|
||||||
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
class LabelTensorDataset(Dataset):
|
|
||||||
def __init__(self, d):
|
|
||||||
for k, v in d.items():
|
|
||||||
setattr(self, k, v)
|
|
||||||
self.labels = list(d.keys())
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
print(index)
|
|
||||||
result = {}
|
|
||||||
for label in self.labels:
|
|
||||||
sample_tensor = getattr(self, label)
|
|
||||||
|
|
||||||
# print('porcodio')
|
|
||||||
# print(sample_tensor.shape[1])
|
|
||||||
# print(index)
|
|
||||||
# print(sample_tensor[index])
|
|
||||||
try:
|
|
||||||
result[label] = sample_tensor[index]
|
|
||||||
except IndexError:
|
|
||||||
result[label] = torch.tensor([])
|
|
||||||
|
|
||||||
print(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return max([len(getattr(self, label)) for label in self.labels])
|
|
||||||
|
|
||||||
# TODO: working also for datapoints
|
# TODO: working also for datapoints
|
||||||
class DummyLoader:
|
class DummyLoader:
|
||||||
@@ -124,4 +52,4 @@ class DummyLoader:
|
|||||||
self.data = [data]
|
self.data = [data]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.data)
|
return iter(self.data)
|
||||||
|
|||||||
Reference in New Issue
Block a user