Adding features to label tensor class (#29)

* adding label.setter for runtime check on labels
This commit is contained in:
Dario Coscia
2022-11-04 17:26:05 +01:00
committed by GitHub
parent d06f28de7b
commit a92a764844
2 changed files with 21 additions and 2 deletions

View File

@@ -66,7 +66,25 @@ class LabelTensor(torch.Tensor):
'the tensor has not the same number of columns of ' 'the tensor has not the same number of columns of '
'the passed labels.' 'the passed labels.'
) )
self.labels = labels self._labels = labels
@property
def labels(self):
"""Property decorator for labels
:return: labels of self
:rtype: list
"""
return self._labels
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[1]: # small check
raise ValueError(
'the tensor has not the same number of columns of '
'the passed labels.')
self._labels = labels # assign the label
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs):
""" """
@@ -120,7 +138,6 @@ class LabelTensor(torch.Tensor):
extracted_tensor = new_data.as_subclass(LabelTensor) extracted_tensor = new_data.as_subclass(LabelTensor)
extracted_tensor.labels = new_labels extracted_tensor.labels = new_labels
return extracted_tensor return extracted_tensor
def append(self, lt, mode='std'): def append(self, lt, mode='std'):

View File

@@ -20,6 +20,8 @@ def test_labels():
tensor = LabelTensor(data, labels) tensor = LabelTensor(data, labels)
assert isinstance(tensor, torch.Tensor) assert isinstance(tensor, torch.Tensor)
assert tensor.labels == labels assert tensor.labels == labels
with pytest.raises(ValueError):
tensor.labels = labels[:-1]
def test_extract(): def test_extract():