fix slicing for LabelTensor (#167)
* fix slicing for LabelTensor * Update testing_pr.yml for solving python3.1 error
This commit is contained in:
@@ -197,8 +197,22 @@ class LabelTensor(torch.Tensor):
|
||||
Return a copy of the selected tensor.
|
||||
"""
|
||||
selected_lt = super(Tensor, self).__getitem__(index)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels
|
||||
|
||||
try:
|
||||
len_index = len(index)
|
||||
except TypeError:
|
||||
len_index = 1
|
||||
|
||||
if isinstance(index, int) or len_index == 1:
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(1, -1)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels
|
||||
elif len_index == 2:
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(-1, 1)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels[index[1]]
|
||||
|
||||
return selected_lt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user