import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate


class PinaDataset():
    """Dataset/dataloader helper built around a PINN object."""

    def __init__(self, pinn) -> None:
        self.pinn = pinn

    @property
    def dataloader(self):
        return self._create_dataloader()

    @property
    def dataset(self):
        return [self.SampleDataset(key, val)
                for key, val in self.pinn.input_pts.items()]
    def _create_dataloader(self):
        """Private method for creating dataloader

        :return: dataloader
        :rtype: torch.utils.data.DataLoader
        """
        if self.pinn.batch_size is None:
            # full-batch mode: each condition gets a single-batch list
            return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}

        def custom_collate(batch):
            # extracting pts labels
            _, pts = list(batch[0].items())[0]
            labels = pts.labels
            # calling default torch collate
            collate_res = default_collate(batch)
            # save collate result in dict, re-attaching the labels dropped by collation
            res = {}
            for key, val in collate_res.items():
                val.labels = labels
                res[key] = val
            return res
        # TODO: build the per-condition DataLoaders with custom_collate for the
        # mini-batch case
    class SampleDataset(torch.utils.data.Dataset):
        """Dataset over the sample points of a single condition."""

        def __init__(self, location, tensor):
            self._tensor = tensor
            self._location = location
            self._len = len(tensor)

        def __getitem__(self, index):
            tensor = self._tensor.select(0, index)
            return {self._location: tensor}

        def __len__(self):
            return self._len
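

# Illustrative usage sketch for the full-batch path above. The `SimpleNamespace`
# stand-in and the `_demo_full_batch_dataloader` helper are hypothetical and not
# part of the module: any object exposing `batch_size` and `input_pts` would do.
def _demo_full_batch_dataloader():
    from types import SimpleNamespace

    fake_pinn = SimpleNamespace(
        batch_size=None,
        input_pts={"D": torch.rand(10, 2), "gamma": torch.rand(4, 2)},
    )
    loader = PinaDataset(fake_pinn).dataloader
    # with batch_size=None the "dataloader" is a dict of single-batch lists:
    # {'D': [{'D': <10x2 tensor>}], 'gamma': [{'gamma': <4x2 tensor>}]}
    for condition, batches in loader.items():
        for batch in batches:
            print(condition, batch[condition].shape)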


class LabelTensorDataset(Dataset):

    def __init__(self, d):
        # store every labelled tensor as an attribute and remember the labels
        for k, v in d.items():
            setattr(self, k, v)
        self.labels = list(d.keys())

    def __getitem__(self, index):
        result = {}
        for label in self.labels:
            sample_tensor = getattr(self, label)
            try:
                result[label] = sample_tensor[index]
            except IndexError:
                # tensors may have different lengths; out-of-range labels get
                # an empty tensor
                result[label] = torch.tensor([])
        return result

    def __len__(self):
        return max(len(getattr(self, label)) for label in self.labels)
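

# Minimal usage sketch; the `_demo_label_tensor_dataset` helper is hypothetical and
# plain torch tensors stand in for pina's LabelTensor objects. Each item is a
# {label: sample} dict, so a standard DataLoader can batch it with default collation.
def _demo_label_tensor_dataset():
    data = {"gamma1": torch.rand(8, 2), "gamma2": torch.rand(8, 2)}
    dataset = LabelTensorDataset(data)
    loader = DataLoader(dataset, batch_size=4)
    for batch in loader:
        # batch is a dict of stacked tensors, e.g. batch['gamma1'] has shape (4, 2)
        print(batch["gamma1"].shape, batch["gamma2"].shape)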


class LabelTensorDataLoader(DataLoader):

    def collate_fn(self, data):
        # placeholder: collation of labelled tensors is not implemented yet
        raise NotImplementedError
        # return dict(zip(self.pinn.input_pts.keys(), dataloaders))


class DummyLoader:
    """Loader that yields the whole data dictionary as one batch."""

    def __init__(self, data) -> None:
        self.data = [data]

    def __iter__(self):
        return iter(self.data)
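

# Minimal usage sketch; the `_demo_dummy_loader` helper is hypothetical. DummyLoader
# simply wraps the whole data dictionary and yields it once per iteration pass.
def _demo_dummy_loader():
    data = {"D": torch.rand(10, 2), "gamma": torch.rand(4, 2)}
    loader = DummyLoader(data)
    for batch in loader:
        # a single "batch" containing all the points
        print(sorted(batch.keys()))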