batch_enhancement (#51)
This commit is contained in:
@@ -79,7 +79,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
if len(labels) != self.shape[1]: # small check
|
||||
if len(labels) != self.shape[1]: # small check
|
||||
raise ValueError(
|
||||
'the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
@@ -106,6 +106,14 @@ class LabelTensor(torch.Tensor):
|
||||
new.data = tmp.data
|
||||
return new
|
||||
|
||||
def select(self, *args, **kwargs):
|
||||
"""
|
||||
Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
|
||||
"""
|
||||
tmp = super().select(*args, **kwargs)
|
||||
tmp._labels = self._labels
|
||||
return tmp
|
||||
|
||||
def extract(self, label_to_extract):
|
||||
"""
|
||||
Extract the subset of the original tensor by returning all the columns
|
||||
|
||||
Reference in New Issue
Block a user