diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 0803fb4..5344930 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -66,7 +66,25 @@ class LabelTensor(torch.Tensor): 'the tensor has not the same number of columns of ' 'the passed labels.' ) - self.labels = labels + self._labels = labels + + @property + def labels(self): + """Property decorator for labels + + :return: labels of self + :rtype: list + """ + return self._labels + + @labels.setter + def labels(self, labels): + if len(labels) != self.shape[1]: # small check + raise ValueError( + 'the tensor has not the same number of columns of ' + 'the passed labels.') + + self._labels = labels # assign the label def clone(self, *args, **kwargs): """ @@ -120,7 +138,6 @@ class LabelTensor(torch.Tensor): extracted_tensor = new_data.as_subclass(LabelTensor) extracted_tensor.labels = new_labels - return extracted_tensor def append(self, lt, mode='std'): diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index df112fa..c6f0edd 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -20,6 +20,8 @@ def test_labels(): tensor = LabelTensor(data, labels) assert isinstance(tensor, torch.Tensor) assert tensor.labels == labels + with pytest.raises(ValueError): + tensor.labels = labels[:-1] def test_extract():