Additional improvement related to #395
This commit is contained in:
committed by
Nicola Demo
parent
afb1bca245
commit
629a6ee43b
@@ -380,20 +380,21 @@ class LabelTensor(torch.Tensor):
|
|||||||
"""
|
"""
|
||||||
old_dof = old_labels[to_update_dim]['dof']
|
old_dof = old_labels[to_update_dim]['dof']
|
||||||
label_name = old_labels[dim]['name']
|
label_name = old_labels[dim]['name']
|
||||||
|
# Handle slicing
|
||||||
if isinstance(index, slice):
|
if isinstance(index, slice):
|
||||||
# Handle slicing
|
|
||||||
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
|
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
|
||||||
|
# Handle single integer index
|
||||||
elif isinstance(index, int):
|
elif isinstance(index, int):
|
||||||
# Handle single integer index
|
|
||||||
to_update_labels[dim] = {'dof': [old_dof[index]],
|
to_update_labels[dim] = {'dof': [old_dof[index]],
|
||||||
'name': label_name}
|
'name': label_name}
|
||||||
|
# Handle lists or tensors
|
||||||
elif isinstance(index, (list, torch.Tensor)):
|
elif isinstance(index, (list, torch.Tensor)):
|
||||||
# Handle lists or tensors
|
# Handle list of bools
|
||||||
indices = [index] if isinstance(index, (int, str)) else index
|
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
||||||
|
index = index.nonzero().squeeze()
|
||||||
to_update_labels[dim] = {
|
to_update_labels[dim] = {
|
||||||
'dof': [old_dof[i] for i in indices] if isinstance(old_dof,
|
'dof': [old_dof[i] for i in index] if isinstance(old_dof,
|
||||||
list) else indices,
|
list) else index,
|
||||||
'name': label_name
|
'name': label_name
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -281,3 +281,17 @@ def test_sorting():
|
|||||||
assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all()
|
assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all()
|
||||||
assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all()
|
assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all()
|
||||||
assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all()
|
assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("labels",
|
||||||
|
[[f's{i}' for i in range(10)],
|
||||||
|
{0: {'dof': ['a', 'b', 'c']},
|
||||||
|
1: {'dof': [f's{i}' for i in range(10)]}}])
|
||||||
|
def test_cat_bool(labels):
|
||||||
|
out = torch.randn((3, 10))
|
||||||
|
out = LabelTensor(out, labels)
|
||||||
|
selected = out[torch.tensor([True, True, False])]
|
||||||
|
assert selected.shape == (2, 10)
|
||||||
|
assert selected.stored_labels[1]['dof'] == [f's{i}' for i in range(10)]
|
||||||
|
if isinstance(labels, dict):
|
||||||
|
assert selected.stored_labels[0]['dof'] == ['a', 'b']
|
||||||
|
|||||||
Reference in New Issue
Block a user