add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -96,6 +96,28 @@ class LabelTensor(torch.Tensor):
self._labels = labels # assign the label
@staticmethod
def vstack(label_tensors):
"""
Stack tensors vertically. For more details, see
:meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: the tensors to stack. They need
to have equal labels.
:return: the stacked tensor
:rtype: LabelTensor
"""
if len(label_tensors) == 0:
return []
all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError('The tensors to stack have different labels')
labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)
# TODO remove try/ except thing IMPORTANT
# make the label None of default
def clone(self, *args, **kwargs):
@@ -183,6 +205,18 @@ class LabelTensor(torch.Tensor):
return extracted_tensor
def detach(self):
detached = super().detach()
if hasattr(self, '_labels'):
detached._labels = self._labels
return detached
def requires_grad_(self, mode = True) -> Tensor:
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
def append(self, lt, mode='std'):
"""
Return a copy of the merged tensors.
@@ -232,7 +266,7 @@ class LabelTensor(torch.Tensor):
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
@@ -246,8 +280,14 @@ class LabelTensor(torch.Tensor):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
return selected_lt
@property
def tensor(self):
return self.as_subclass(Tensor)
def __len__(self) -> int:
return super().__len__()