diff --git a/.github/workflows/testing_pr.yml b/.github/workflows/testing_pr.yml index 796bcc9..06045cf 100644 --- a/.github/workflows/testing_pr.yml +++ b/.github/workflows/testing_pr.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [windows-latest, macos-latest, ubuntu-latest] - python-version: [3.7, 3.8, 3.9, 3.10] + python-version: [3.7, 3.8, 3.9, '3.10'] steps: - uses: actions/checkout@v2 diff --git a/pina/label_tensor.py b/pina/label_tensor.py index f79420c..c9e03bf 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -197,8 +197,22 @@ class LabelTensor(torch.Tensor): Return a copy of the selected tensor. """ selected_lt = super(Tensor, self).__getitem__(index) - if hasattr(self, 'labels'): - selected_lt.labels = self.labels + + try: + len_index = len(index) + except TypeError: + len_index = 1 + + if isinstance(index, int) or len_index == 1: + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(1, -1) + if hasattr(self, 'labels'): + selected_lt.labels = self.labels + elif len_index == 2: + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(-1, 1) + if hasattr(self, 'labels'): + selected_lt.labels = self.labels[index[1]] return selected_lt diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index ce601f0..161ffa4 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -73,7 +73,7 @@ def test_merge(): tensor_bc = tensor_b.append(tensor_c) assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) -def test_merge(): +def test_merge2(): tensor = LabelTensor(data, labels) tensor_b = tensor.extract('b') tensor_c = tensor.extract('c') @@ -87,4 +87,18 @@ def test_getitem(): tensor_view = tensor[:5] assert tensor_view.labels == labels - assert torch.allclose(tensor_view, data[:5]) \ No newline at end of file + assert torch.allclose(tensor_view, data[:5]) + +def test_slice(): + tensor = LabelTensor(data, labels) + tensor_view = tensor[:5, :2] + assert tensor_view.labels == labels[:2] + assert torch.allclose(tensor_view, data[:5, :2]) + + tensor_view2 = tensor[3] + assert tensor_view2.labels == labels + assert torch.allclose(tensor_view2, data[3]) + + tensor_view3 = tensor[:, 2] + assert tensor_view3.labels == labels[2] + assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) \ No newline at end of file