Label Tensor update (#188)

* Update test_label_tensor.py
* adding test

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Dario Coscia
2023-11-02 12:43:32 +01:00
committed by Nicola Demo
parent d556c592e0
commit cd5bc9a558
2 changed files with 47 additions and 7 deletions

View File

@@ -33,6 +33,14 @@ class LabelTensor(torch.Tensor):
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671],
[0.9239],
[0.8927],
@@ -69,7 +77,7 @@ class LabelTensor(torch.Tensor):
'the passed labels.'
)
self._labels = labels
@property
def labels(self):
"""Property decorator for labels
@@ -100,7 +108,7 @@ class LabelTensor(torch.Tensor):
"""
try:
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except:
except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
out = super().clone(*args, **kwargs)
return out
@@ -123,6 +131,24 @@ class LabelTensor(torch.Tensor):
tmp._labels = self._labels
return tmp
def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp
def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp
def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
@@ -149,7 +175,7 @@ class LabelTensor(torch.Tensor):
except ValueError:
raise ValueError(f'`{f}` not in the labels list')
new_data = super(Tensor, self.T).__getitem__(indeces).float().T
new_data = super(Tensor, self.T).__getitem__(indeces).T
new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = new_data.as_subclass(LabelTensor)
@@ -196,8 +222,12 @@ class LabelTensor(torch.Tensor):
"""
Return a copy of the selected tensor.
"""
selected_lt = super(Tensor, self).__getitem__(index)
if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
return self.extract(index)
selected_lt = super(Tensor, self).__getitem__(index)
try:
len_index = len(index)
except TypeError:

View File

@@ -27,7 +27,6 @@ def test_labels():
def test_extract():
label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels)
print(tensor)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
@@ -58,7 +57,6 @@ def test_extract_order():
expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1)
print(expected)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
@@ -83,6 +81,18 @@ def test_merge2():
def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor['a']
assert tensor_view.labels == ['a']
assert torch.allclose(tensor_view.flatten(), data[:, 0])
tensor_view = tensor['a', 'c']
assert tensor_view.labels == ['a', 'c']
assert torch.allclose(tensor_view, data[:, 0::2])
def test_getitem2():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]
@@ -101,4 +111,4 @@ def test_slice():
tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))