diff --git a/pina/dataset.py b/pina/dataset.py index ac81f62..f8c41a2 100644 --- a/pina/dataset.py +++ b/pina/dataset.py @@ -1,3 +1,7 @@ +""" """ +from torch.utils.data import Dataset, DataLoader + + class PinaDataset(): def __init__(self, pinn) -> None: @@ -39,83 +43,7 @@ class PinaDataset(): def __len__(self): return self._len -from torch.utils.data import Dataset, DataLoader -class LabelTensorDataset(Dataset): - def __init__(self, d): - for k, v in d.items(): - setattr(self, k, v) - self.labels = list(d.keys()) - - def __getitem__(self, index): - print(index) - result = {} - for label in self.labels: - sample_tensor = getattr(self, label) - # print('porcodio') - # print(sample_tensor.shape[1]) - # print(index) - # print(sample_tensor[index]) - try: - result[label] = sample_tensor[index] - except IndexError: - result[label] = torch.tensor([]) - - print(result) - return result - - def __len__(self): - return max([len(getattr(self, label)) for label in self.labels]) - - -class LabelTensorDataLoader(DataLoader): - - def collate_fn(self, data): - print(data) - gggggggggg -# return dict(zip(self.pinn.input_pts.keys(), dataloaders)) - -# class SampleDataset(torch.utils.data.Dataset): - -# def __init__(self, location, tensor): -# self._tensor = tensor -# self._location = location -# self._len = len(tensor) - -# def __getitem__(self, index): -# tensor = self._tensor.select(0, index) -# return {self._location: tensor} - -# def __len__(self): -# return self._len - -from torch.utils.data import Dataset, DataLoader -class LabelTensorDataset(Dataset): - def __init__(self, d): - for k, v in d.items(): - setattr(self, k, v) - self.labels = list(d.keys()) - - def __getitem__(self, index): - print(index) - result = {} - for label in self.labels: - sample_tensor = getattr(self, label) - - # print('porcodio') - # print(sample_tensor.shape[1]) - # print(index) - # print(sample_tensor[index]) - try: - result[label] = sample_tensor[index] - except IndexError: - result[label] = torch.tensor([]) - - print(result) - return result - - def __len__(self): - return max([len(getattr(self, label)) for label in self.labels]) # TODO: working also for datapoints class DummyLoader: @@ -124,4 +52,4 @@ class DummyLoader: self.data = [data] def __iter__(self): - return iter(self.data) \ No newline at end of file + return iter(self.data)