Additional improvement related to #395
This commit is contained in:
committed by
Nicola Demo
parent
afb1bca245
commit
629a6ee43b
@@ -380,20 +380,21 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
old_dof = old_labels[to_update_dim]['dof']
|
||||
label_name = old_labels[dim]['name']
|
||||
|
||||
# Handle slicing
|
||||
if isinstance(index, slice):
|
||||
# Handle slicing
|
||||
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
|
||||
# Handle single integer index
|
||||
elif isinstance(index, int):
|
||||
# Handle single integer index
|
||||
to_update_labels[dim] = {'dof': [old_dof[index]],
|
||||
'name': label_name}
|
||||
# Handle lists or tensors
|
||||
elif isinstance(index, (list, torch.Tensor)):
|
||||
# Handle lists or tensors
|
||||
indices = [index] if isinstance(index, (int, str)) else index
|
||||
# Handle list of bools
|
||||
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
||||
index = index.nonzero().squeeze()
|
||||
to_update_labels[dim] = {
|
||||
'dof': [old_dof[i] for i in indices] if isinstance(old_dof,
|
||||
list) else indices,
|
||||
'dof': [old_dof[i] for i in index] if isinstance(old_dof,
|
||||
list) else index,
|
||||
'name': label_name
|
||||
}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user