diff --git a/pina/label_tensor.py b/pina/label_tensor.py index d6f57e3..5645381 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -33,6 +33,14 @@ class LabelTensor(torch.Tensor): [1.0246e-01, 9.5179e-01, 3.7043e-02], [9.6150e-01, 8.0656e-01, 8.3824e-01]]) >>> tensor.extract('a') + tensor([[0.0671], + [0.9239], + [0.8927], + ..., + [0.5819], + [0.1025], + [0.9615]]) + >>> tensor['a'] tensor([[0.0671], [0.9239], [0.8927], @@ -69,7 +77,7 @@ class LabelTensor(torch.Tensor): 'the passed labels.' ) self._labels = labels - + @property def labels(self): """Property decorator for labels @@ -100,7 +108,7 @@ class LabelTensor(torch.Tensor): """ try: out = LabelTensor(super().clone(*args, **kwargs), self.labels) - except: + except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining out = super().clone(*args, **kwargs) return out @@ -123,6 +131,24 @@ class LabelTensor(torch.Tensor): tmp._labels = self._labels return tmp + def cuda(self, *args, **kwargs): + """ + Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. + """ + tmp = super().cuda(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp + + def cpu(self, *args, **kwargs): + """ + Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. + """ + tmp = super().cpu(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp + def extract(self, label_to_extract): """ Extract the subset of the original tensor by returning all the columns @@ -149,7 +175,7 @@ class LabelTensor(torch.Tensor): except ValueError: raise ValueError(f'`{f}` not in the labels list') - new_data = super(Tensor, self.T).__getitem__(indeces).float().T + new_data = super(Tensor, self.T).__getitem__(indeces).T new_labels = [self.labels[idx] for idx in indeces] extracted_tensor = new_data.as_subclass(LabelTensor) @@ -196,8 +222,12 @@ class LabelTensor(torch.Tensor): """ Return a copy of the selected tensor. """ - selected_lt = super(Tensor, self).__getitem__(index) + if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)): + return self.extract(index) + + selected_lt = super(Tensor, self).__getitem__(index) + try: len_index = len(index) except TypeError: diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index 161ffa4..4e6a302 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -27,7 +27,6 @@ def test_labels(): def test_extract(): label_to_extract = ['a', 'c'] tensor = LabelTensor(data, labels) - print(tensor) new = tensor.extract(label_to_extract) assert new.labels == label_to_extract assert new.shape[1] == len(label_to_extract) @@ -58,7 +57,6 @@ def test_extract_order(): expected = torch.cat( (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), dim=1) - print(expected) assert new.labels == label_to_extract assert new.shape[1] == len(label_to_extract) assert torch.all(torch.isclose(expected, new)) @@ -83,6 +81,18 @@ def test_merge2(): def test_getitem(): + tensor = LabelTensor(data, labels) + tensor_view = tensor['a'] + + assert tensor_view.labels == ['a'] + assert torch.allclose(tensor_view.flatten(), data[:, 0]) + + tensor_view = tensor['a', 'c'] + + assert tensor_view.labels == ['a', 'c'] + assert torch.allclose(tensor_view, data[:, 0::2]) + +def test_getitem2(): tensor = LabelTensor(data, labels) tensor_view = tensor[:5] @@ -101,4 +111,4 @@ def test_slice(): tensor_view3 = tensor[:, 2] assert tensor_view3.labels == labels[2] - assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) \ No newline at end of file + assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))