From bb066f76817a473e5435cadaf54ba41cdf7d9213 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Wed, 21 Jun 2023 17:19:36 +0200 Subject: [PATCH] Update utils.py --- pina/utils.py | 76 --------------------------------------------------- 1 file changed, 76 deletions(-) diff --git a/pina/utils.py b/pina/utils.py index a6bf0c5..bff9d7b 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -187,82 +187,6 @@ def chebyshev_roots(n): # return self._len -class LabelTensorDataset(Dataset): - def __init__(self, d): - for k, v in d.items(): - setattr(self, k, v) - self.labels = list(d.keys()) - - def __getitem__(self, index): - print(index) - result = {} - for label in self.labels: - sample_tensor = getattr(self, label) - - # print('porcodio') - # print(sample_tensor.shape[1]) - # print(index) - # print(sample_tensor[index]) - try: - result[label] = sample_tensor[index] - except IndexError: - result[label] = torch.tensor([]) - - print(result) - return result - - def __len__(self): - return max([len(getattr(self, label)) for label in self.labels]) - - -class LabelTensorDataLoader(DataLoader): - - def collate_fn(self, data): - pass -# return dict(zip(self.pinn.input_pts.keys(), dataloaders)) - -# class SampleDataset(torch.utils.data.Dataset): - -# def __init__(self, location, tensor): -# self._tensor = tensor -# self._location = location -# self._len = len(tensor) - -# def __getitem__(self, index): -# tensor = self._tensor.select(0, index) -# return {self._location: tensor} - -# def __len__(self): -# return self._len - - -class LabelTensorDataset(Dataset): - def __init__(self, d): - for k, v in d.items(): - setattr(self, k, v) - self.labels = list(d.keys()) - - def __getitem__(self, index): - print(index) - result = {} - for label in self.labels: - sample_tensor = getattr(self, label) - - # print('porcodio') - # print(sample_tensor.shape[1]) - # print(index) - # print(sample_tensor[index]) - try: - result[label] = sample_tensor[index] - except IndexError: - result[label] = torch.tensor([]) - - print(result) - return result - - def __len__(self): - return max([len(getattr(self, label)) for label in self.labels]) - class LabelTensorDataLoader(DataLoader):