batch_enhancement (#51)

This commit is contained in:
Dario Coscia
2022-12-12 11:09:20 +01:00
committed by GitHub
parent d70f5e730a
commit dbd78c9cf3
4 changed files with 236 additions and 59 deletions

View File

@@ -79,7 +79,7 @@ class LabelTensor(torch.Tensor):
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[1]: # small check
if len(labels) != self.shape[1]: # small check
raise ValueError(
'the tensor has not the same number of columns of '
'the passed labels.')
@@ -106,6 +106,14 @@ class LabelTensor(torch.Tensor):
new.data = tmp.data
return new
def select(self, *args, **kwargs):
"""
Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
"""
tmp = super().select(*args, **kwargs)
tmp._labels = self._labels
return tmp
def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns