Label Tensor update (#188)
* Update test_label_tensor.py
* adding test
---------
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
committed by Nicola Demo
parent d556c592e0
commit cd5bc9a558
@@ -33,6 +33,14 @@ class LabelTensor(torch.Tensor):
             [1.0246e-01, 9.5179e-01, 3.7043e-02],
             [9.6150e-01, 8.0656e-01, 8.3824e-01]])
     >>> tensor.extract('a')
+    tensor([[0.0671],
+            [0.9239],
+            [0.8927],
+            ...,
+            [0.5819],
+            [0.1025],
+            [0.9615]])
+    >>> tensor['a']
     tensor([[0.0671],
             [0.9239],
             [0.8927],
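For reference, the doctest above can be reproduced with a snippet along these lines (a minimal sketch; the import path and the random values are assumptions, not part of the diff):

import torch
from pina import LabelTensor  # import path assumed

data = torch.rand((20, 3))
tensor = LabelTensor(data, ['a', 'b', 'c'])
print(tensor.extract('a'))   # the column labelled 'a', returned as a LabelTensor
print(tensor['a'])           # equivalent spelling via the new __getitem__ support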
@@ -69,7 +77,7 @@ class LabelTensor(torch.Tensor):
                 'the passed labels.'
             )
         self._labels = labels
 
     @property
     def labels(self):
         """Property decorator for labels
@@ -100,7 +108,7 @@ class LabelTensor(torch.Tensor):
         """
         try:
             out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except:
+        except:  # this is used when the tensor loses the labels; notice it will create a bug! Kept for compatibility with Lightning
             out = super().clone(*args, **kwargs)
 
         return out
@@ -123,6 +131,24 @@ class LabelTensor(torch.Tensor):
         tmp._labels = self._labels
         return tmp
 
+    def cuda(self, *args, **kwargs):
+        """
+        Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
+        """
+        tmp = super().cuda(*args, **kwargs)
+        new = self.__class__.clone(self)
+        new.data = tmp.data
+        return tmp
+
+    def cpu(self, *args, **kwargs):
+        """
+        Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
+        """
+        tmp = super().cpu(*args, **kwargs)
+        new = self.__class__.clone(self)
+        new.data = tmp.data
+        return tmp
+
     def extract(self, label_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
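A minimal usage sketch of the new device helpers (illustrative values; the CUDA branch obviously needs a GPU, and the import path is an assumption):

import torch
from pina import LabelTensor  # import path assumed

lt = LabelTensor(torch.rand((10, 3)), ['a', 'b', 'c'])
lt_cpu = lt.cpu()             # tensor data on CPU
if torch.cuda.is_available():
    lt_gpu = lt.cuda()        # tensor data moved to the default CUDA device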
@@ -149,7 +175,7 @@ class LabelTensor(torch.Tensor):
             except ValueError:
                 raise ValueError(f'`{f}` not in the labels list')
 
-        new_data = super(Tensor, self.T).__getitem__(indeces).float().T
+        new_data = super(Tensor, self.T).__getitem__(indeces).T
         new_labels = [self.labels[idx] for idx in indeces]
 
         extracted_tensor = new_data.as_subclass(LabelTensor)
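For intuition, the changed line selects columns by indexing the transposed tensor and transposing back; a rough standalone sketch of the same idea in plain torch (illustrative, not the library code):

import torch

data = torch.rand((5, 3))
labels = ['a', 'b', 'c']
indeces = [labels.index(f) for f in ['a', 'c']]  # column positions for the requested labels
new_data = data.T[indeces].T                     # pick those columns; the dtype is left untouched now that .float() is gone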
@@ -196,8 +222,12 @@ class LabelTensor(torch.Tensor):
         """
         Return a copy of the selected tensor.
         """
-        selected_lt = super(Tensor, self).__getitem__(index)
 
+        if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
+            return self.extract(index)
+
+        selected_lt = super(Tensor, self).__getitem__(index)
+
         try:
             len_index = len(index)
         except TypeError:
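With this change, string keys (and tuples or lists made only of strings) are routed to extract, while any other index keeps the usual tensor behaviour; roughly (a sketch, assuming the import path):

import torch
from pina import LabelTensor  # import path assumed

lt = LabelTensor(torch.rand((20, 3)), ['a', 'b', 'c'])
lt['a']         # dispatched to lt.extract('a')
lt['a', 'c']    # dispatched to extract, selecting the 'a' and 'c' columns
lt[:5]          # plain slicing, handled by the original tensor __getitem__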
test_label_tensor.py
@@ -27,7 +27,6 @@ def test_labels():
 def test_extract():
     label_to_extract = ['a', 'c']
     tensor = LabelTensor(data, labels)
-    print(tensor)
     new = tensor.extract(label_to_extract)
     assert new.labels == label_to_extract
     assert new.shape[1] == len(label_to_extract)
@@ -58,7 +57,6 @@ def test_extract_order():
     expected = torch.cat(
         (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
         dim=1)
-    print(expected)
     assert new.labels == label_to_extract
     assert new.shape[1] == len(label_to_extract)
     assert torch.all(torch.isclose(expected, new))
@@ -83,6 +81,18 @@ def test_merge2():
 
 
 def test_getitem():
+    tensor = LabelTensor(data, labels)
+    tensor_view = tensor['a']
+
+    assert tensor_view.labels == ['a']
+    assert torch.allclose(tensor_view.flatten(), data[:, 0])
+
+    tensor_view = tensor['a', 'c']
+
+    assert tensor_view.labels == ['a', 'c']
+    assert torch.allclose(tensor_view, data[:, 0::2])
+
+def test_getitem2():
     tensor = LabelTensor(data, labels)
     tensor_view = tensor[:5]
 
@@ -101,4 +111,4 @@ def test_slice():
 
     tensor_view3 = tensor[:, 2]
     assert tensor_view3.labels == labels[2]
     assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))