Label Tensor update (#188)

* Update test_label_tensor.py
* adding test

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Dario Coscia
2023-11-02 12:43:32 +01:00
committed by Nicola Demo
parent d556c592e0
commit cd5bc9a558
2 changed files with 47 additions and 7 deletions

View File

@@ -33,6 +33,14 @@ class LabelTensor(torch.Tensor):
[1.0246e-01, 9.5179e-01, 3.7043e-02], [1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]]) [9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a') >>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671], tensor([[0.0671],
[0.9239], [0.9239],
[0.8927], [0.8927],
@@ -69,7 +77,7 @@ class LabelTensor(torch.Tensor):
'the passed labels.' 'the passed labels.'
) )
self._labels = labels self._labels = labels
@property @property
def labels(self): def labels(self):
"""Property decorator for labels """Property decorator for labels
@@ -100,7 +108,7 @@ class LabelTensor(torch.Tensor):
""" """
try: try:
out = LabelTensor(super().clone(*args, **kwargs), self.labels) out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except: except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
out = super().clone(*args, **kwargs) out = super().clone(*args, **kwargs)
return out return out
@@ -123,6 +131,24 @@ class LabelTensor(torch.Tensor):
tmp._labels = self._labels tmp._labels = self._labels
return tmp return tmp
def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp
def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp
def extract(self, label_to_extract): def extract(self, label_to_extract):
""" """
Extract the subset of the original tensor by returning all the columns Extract the subset of the original tensor by returning all the columns
@@ -149,7 +175,7 @@ class LabelTensor(torch.Tensor):
except ValueError: except ValueError:
raise ValueError(f'`{f}` not in the labels list') raise ValueError(f'`{f}` not in the labels list')
new_data = super(Tensor, self.T).__getitem__(indeces).float().T new_data = super(Tensor, self.T).__getitem__(indeces).T
new_labels = [self.labels[idx] for idx in indeces] new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = new_data.as_subclass(LabelTensor) extracted_tensor = new_data.as_subclass(LabelTensor)
@@ -196,8 +222,12 @@ class LabelTensor(torch.Tensor):
""" """
Return a copy of the selected tensor. Return a copy of the selected tensor.
""" """
selected_lt = super(Tensor, self).__getitem__(index)
if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
return self.extract(index)
selected_lt = super(Tensor, self).__getitem__(index)
try: try:
len_index = len(index) len_index = len(index)
except TypeError: except TypeError:

View File

@@ -27,7 +27,6 @@ def test_labels():
def test_extract(): def test_extract():
label_to_extract = ['a', 'c'] label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels) tensor = LabelTensor(data, labels)
print(tensor)
new = tensor.extract(label_to_extract) new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract) assert new.shape[1] == len(label_to_extract)
@@ -58,7 +57,6 @@ def test_extract_order():
expected = torch.cat( expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1) dim=1)
print(expected)
assert new.labels == label_to_extract assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract) assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new)) assert torch.all(torch.isclose(expected, new))
@@ -83,6 +81,18 @@ def test_merge2():
def test_getitem(): def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor['a']
assert tensor_view.labels == ['a']
assert torch.allclose(tensor_view.flatten(), data[:, 0])
tensor_view = tensor['a', 'c']
assert tensor_view.labels == ['a', 'c']
assert torch.allclose(tensor_view, data[:, 0::2])
def test_getitem2():
tensor = LabelTensor(data, labels) tensor = LabelTensor(data, labels)
tensor_view = tensor[:5] tensor_view = tensor[:5]
@@ -101,4 +111,4 @@ def test_slice():
tensor_view3 = tensor[:, 2] tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2] assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))