add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points * unittests
This commit is contained in:
@@ -96,6 +96,28 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
self._labels = labels # assign the label
|
||||
|
||||
@staticmethod
|
||||
def vstack(label_tensors):
|
||||
"""
|
||||
Stack tensors vertically. For more details, see
|
||||
:meth:`torch.vstack`.
|
||||
|
||||
:param list(LabelTensor) label_tensors: the tensors to stack. They need
|
||||
to have equal labels.
|
||||
:return: the stacked tensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if len(label_tensors) == 0:
|
||||
return []
|
||||
|
||||
all_labels = [label for lt in label_tensors for label in lt.labels]
|
||||
if set(all_labels) != set(label_tensors[0].labels):
|
||||
raise RuntimeError('The tensors to stack have different labels')
|
||||
|
||||
labels = label_tensors[0].labels
|
||||
tensors = [lt.extract(labels) for lt in label_tensors]
|
||||
return LabelTensor(torch.vstack(tensors), labels)
|
||||
|
||||
# TODO remove try/ except thing IMPORTANT
|
||||
# make the label None of default
|
||||
def clone(self, *args, **kwargs):
|
||||
@@ -183,6 +205,18 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
return extracted_tensor
|
||||
|
||||
def detach(self):
|
||||
detached = super().detach()
|
||||
if hasattr(self, '_labels'):
|
||||
detached._labels = self._labels
|
||||
return detached
|
||||
|
||||
|
||||
def requires_grad_(self, mode = True) -> Tensor:
|
||||
lt = super().requires_grad_(mode)
|
||||
lt.labels = self.labels
|
||||
return lt
|
||||
|
||||
def append(self, lt, mode='std'):
|
||||
"""
|
||||
Return a copy of the merged tensors.
|
||||
@@ -232,7 +266,7 @@ class LabelTensor(torch.Tensor):
|
||||
len_index = len(index)
|
||||
except TypeError:
|
||||
len_index = 1
|
||||
|
||||
|
||||
if isinstance(index, int) or len_index == 1:
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(1, -1)
|
||||
@@ -246,8 +280,14 @@ class LabelTensor(torch.Tensor):
|
||||
selected_lt.labels = [self.labels[i] for i in index[1]]
|
||||
else:
|
||||
selected_lt.labels = self.labels[index[1]]
|
||||
else:
|
||||
selected_lt.labels = self.labels
|
||||
|
||||
return selected_lt
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
return self.as_subclass(Tensor)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return super().__len__()
|
||||
|
||||
Reference in New Issue
Block a user