fix slicing for LabelTensor (#167)

* fix slicing for LabelTensor
* Update testing_pr.yml for solving python3.1 error
This commit is contained in:
Nicola Demo
2023-07-22 15:44:52 +02:00
parent ba0b9760ac
commit bd88e24174
3 changed files with 33 additions and 5 deletions

View File

@@ -197,8 +197,22 @@ class LabelTensor(torch.Tensor):
Return a copy of the selected tensor.
"""
selected_lt = super(Tensor, self).__getitem__(index)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels
try:
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels[index[1]]
return selected_lt