fix slicing for LabelTensor (#167)

* fix slicing for LabelTensor
* Update testing_pr.yml for solving python3.1 error
This commit is contained in:
Nicola Demo
2023-07-22 15:44:52 +02:00
parent ba0b9760ac
commit bd88e24174
3 changed files with 33 additions and 5 deletions

View File

@@ -14,7 +14,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: [windows-latest, macos-latest, ubuntu-latest] os: [windows-latest, macos-latest, ubuntu-latest]
python-version: [3.7, 3.8, 3.9, 3.10] python-version: [3.7, 3.8, 3.9, '3.10']
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@@ -197,8 +197,22 @@ class LabelTensor(torch.Tensor):
Return a copy of the selected tensor. Return a copy of the selected tensor.
""" """
selected_lt = super(Tensor, self).__getitem__(index) selected_lt = super(Tensor, self).__getitem__(index)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels try:
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels[index[1]]
return selected_lt return selected_lt

View File

@@ -73,7 +73,7 @@ def test_merge():
tensor_bc = tensor_b.append(tensor_c) tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
def test_merge(): def test_merge2():
tensor = LabelTensor(data, labels) tensor = LabelTensor(data, labels)
tensor_b = tensor.extract('b') tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c') tensor_c = tensor.extract('c')
@@ -87,4 +87,18 @@ def test_getitem():
tensor_view = tensor[:5] tensor_view = tensor[:5]
assert tensor_view.labels == labels assert tensor_view.labels == labels
assert torch.allclose(tensor_view, data[:5]) assert torch.allclose(tensor_view, data[:5])
def test_slice():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5, :2]
assert tensor_view.labels == labels[:2]
assert torch.allclose(tensor_view, data[:5, :2])
tensor_view2 = tensor[3]
assert tensor_view2.labels == labels
assert torch.allclose(tensor_view2, data[3])
tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))