Additional improvement related to #395

This commit is contained in:
FilippoOlivo
2025-01-24 09:47:57 +01:00
committed by Nicola Demo
parent afb1bca245
commit 629a6ee43b
2 changed files with 22 additions and 7 deletions

View File

@@ -380,20 +380,21 @@ class LabelTensor(torch.Tensor):
""" """
old_dof = old_labels[to_update_dim]['dof'] old_dof = old_labels[to_update_dim]['dof']
label_name = old_labels[dim]['name'] label_name = old_labels[dim]['name']
# Handle slicing
if isinstance(index, slice): if isinstance(index, slice):
# Handle slicing
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name} to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
# Handle single integer index
elif isinstance(index, int): elif isinstance(index, int):
# Handle single integer index
to_update_labels[dim] = {'dof': [old_dof[index]], to_update_labels[dim] = {'dof': [old_dof[index]],
'name': label_name} 'name': label_name}
# Handle lists or tensors
elif isinstance(index, (list, torch.Tensor)): elif isinstance(index, (list, torch.Tensor)):
# Handle lists or tensors # Handle list of bools
indices = [index] if isinstance(index, (int, str)) else index if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze()
to_update_labels[dim] = { to_update_labels[dim] = {
'dof': [old_dof[i] for i in indices] if isinstance(old_dof, 'dof': [old_dof[i] for i in index] if isinstance(old_dof,
list) else indices, list) else index,
'name': label_name 'name': label_name
} }
else: else:

View File

@@ -281,3 +281,17 @@ def test_sorting():
assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all() assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all()
assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all() assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all()
assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all() assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all()
@pytest.mark.parametrize("labels",
[[f's{i}' for i in range(10)],
{0: {'dof': ['a', 'b', 'c']},
1: {'dof': [f's{i}' for i in range(10)]}}])
def test_cat_bool(labels):
out = torch.randn((3, 10))
out = LabelTensor(out, labels)
selected = out[torch.tensor([True, True, False])]
assert selected.shape == (2, 10)
assert selected.stored_labels[1]['dof'] == [f's{i}' for i in range(10)]
if isinstance(labels, dict):
assert selected.stored_labels[0]['dof'] == ['a', 'b']