fix slicing for LabelTensor (#167)
* fix slicing for LabelTensor * Update testing_pr.yml for solving python3.1 error
This commit is contained in:
2
.github/workflows/testing_pr.yml
vendored
2
.github/workflows/testing_pr.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [windows-latest, macos-latest, ubuntu-latest]
|
os: [windows-latest, macos-latest, ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9, 3.10]
|
python-version: [3.7, 3.8, 3.9, '3.10']
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|||||||
@@ -197,8 +197,22 @@ class LabelTensor(torch.Tensor):
|
|||||||
Return a copy of the selected tensor.
|
Return a copy of the selected tensor.
|
||||||
"""
|
"""
|
||||||
selected_lt = super(Tensor, self).__getitem__(index)
|
selected_lt = super(Tensor, self).__getitem__(index)
|
||||||
if hasattr(self, 'labels'):
|
|
||||||
selected_lt.labels = self.labels
|
try:
|
||||||
|
len_index = len(index)
|
||||||
|
except TypeError:
|
||||||
|
len_index = 1
|
||||||
|
|
||||||
|
if isinstance(index, int) or len_index == 1:
|
||||||
|
if selected_lt.ndim == 1:
|
||||||
|
selected_lt = selected_lt.reshape(1, -1)
|
||||||
|
if hasattr(self, 'labels'):
|
||||||
|
selected_lt.labels = self.labels
|
||||||
|
elif len_index == 2:
|
||||||
|
if selected_lt.ndim == 1:
|
||||||
|
selected_lt = selected_lt.reshape(-1, 1)
|
||||||
|
if hasattr(self, 'labels'):
|
||||||
|
selected_lt.labels = self.labels[index[1]]
|
||||||
|
|
||||||
return selected_lt
|
return selected_lt
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ def test_merge():
|
|||||||
tensor_bc = tensor_b.append(tensor_c)
|
tensor_bc = tensor_b.append(tensor_c)
|
||||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||||
|
|
||||||
def test_merge():
|
def test_merge2():
|
||||||
tensor = LabelTensor(data, labels)
|
tensor = LabelTensor(data, labels)
|
||||||
tensor_b = tensor.extract('b')
|
tensor_b = tensor.extract('b')
|
||||||
tensor_c = tensor.extract('c')
|
tensor_c = tensor.extract('c')
|
||||||
@@ -87,4 +87,18 @@ def test_getitem():
|
|||||||
tensor_view = tensor[:5]
|
tensor_view = tensor[:5]
|
||||||
|
|
||||||
assert tensor_view.labels == labels
|
assert tensor_view.labels == labels
|
||||||
assert torch.allclose(tensor_view, data[:5])
|
assert torch.allclose(tensor_view, data[:5])
|
||||||
|
|
||||||
|
def test_slice():
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
tensor_view = tensor[:5, :2]
|
||||||
|
assert tensor_view.labels == labels[:2]
|
||||||
|
assert torch.allclose(tensor_view, data[:5, :2])
|
||||||
|
|
||||||
|
tensor_view2 = tensor[3]
|
||||||
|
assert tensor_view2.labels == labels
|
||||||
|
assert torch.allclose(tensor_view2, data[3])
|
||||||
|
|
||||||
|
tensor_view3 = tensor[:, 2]
|
||||||
|
assert tensor_view3.labels == labels[2]
|
||||||
|
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||||
Reference in New Issue
Block a user