Adding features to label tensor class (#29)
* adding label.setter for runtime check on labels
This commit is contained in:
@@ -66,7 +66,25 @@ class LabelTensor(torch.Tensor):
|
|||||||
'the tensor has not the same number of columns of '
|
'the tensor has not the same number of columns of '
|
||||||
'the passed labels.'
|
'the passed labels.'
|
||||||
)
|
)
|
||||||
self.labels = labels
|
self._labels = labels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
"""Property decorator for labels
|
||||||
|
|
||||||
|
:return: labels of self
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
return self._labels
|
||||||
|
|
||||||
|
@labels.setter
|
||||||
|
def labels(self, labels):
|
||||||
|
if len(labels) != self.shape[1]: # small check
|
||||||
|
raise ValueError(
|
||||||
|
'the tensor has not the same number of columns of '
|
||||||
|
'the passed labels.')
|
||||||
|
|
||||||
|
self._labels = labels # assign the label
|
||||||
|
|
||||||
def clone(self, *args, **kwargs):
|
def clone(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -120,7 +138,6 @@ class LabelTensor(torch.Tensor):
|
|||||||
extracted_tensor = new_data.as_subclass(LabelTensor)
|
extracted_tensor = new_data.as_subclass(LabelTensor)
|
||||||
extracted_tensor.labels = new_labels
|
extracted_tensor.labels = new_labels
|
||||||
|
|
||||||
|
|
||||||
return extracted_tensor
|
return extracted_tensor
|
||||||
|
|
||||||
def append(self, lt, mode='std'):
|
def append(self, lt, mode='std'):
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ def test_labels():
|
|||||||
tensor = LabelTensor(data, labels)
|
tensor = LabelTensor(data, labels)
|
||||||
assert isinstance(tensor, torch.Tensor)
|
assert isinstance(tensor, torch.Tensor)
|
||||||
assert tensor.labels == labels
|
assert tensor.labels == labels
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tensor.labels = labels[:-1]
|
||||||
|
|
||||||
|
|
||||||
def test_extract():
|
def test_extract():
|
||||||
|
|||||||
Reference in New Issue
Block a user