Update utils.py

This commit is contained in:
Nicola Demo
2023-06-21 17:19:36 +02:00
parent 982af4a04d
commit bb066f7681

View File

@@ -187,82 +187,6 @@ def chebyshev_roots(n):
# return self._len
class LabelTensorDataset(Dataset):
def __init__(self, d):
for k, v in d.items():
setattr(self, k, v)
self.labels = list(d.keys())
def __getitem__(self, index):
print(index)
result = {}
for label in self.labels:
sample_tensor = getattr(self, label)
# print('porcodio')
# print(sample_tensor.shape[1])
# print(index)
# print(sample_tensor[index])
try:
result[label] = sample_tensor[index]
except IndexError:
result[label] = torch.tensor([])
print(result)
return result
def __len__(self):
return max([len(getattr(self, label)) for label in self.labels])
class LabelTensorDataLoader(DataLoader):
def collate_fn(self, data):
pass
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
# class SampleDataset(torch.utils.data.Dataset):
# def __init__(self, location, tensor):
# self._tensor = tensor
# self._location = location
# self._len = len(tensor)
# def __getitem__(self, index):
# tensor = self._tensor.select(0, index)
# return {self._location: tensor}
# def __len__(self):
# return self._len
class LabelTensorDataset(Dataset):
def __init__(self, d):
for k, v in d.items():
setattr(self, k, v)
self.labels = list(d.keys())
def __getitem__(self, index):
print(index)
result = {}
for label in self.labels:
sample_tensor = getattr(self, label)
# print('porcodio')
# print(sample_tensor.shape[1])
# print(index)
# print(sample_tensor[index])
try:
result[label] = sample_tensor[index]
except IndexError:
result[label] = torch.tensor([])
print(result)
return result
def __len__(self):
return max([len(getattr(self, label)) for label in self.labels])
class LabelTensorDataLoader(DataLoader):