diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 5ffa611..d8f66f9 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -380,20 +380,21 @@ class LabelTensor(torch.Tensor): """ old_dof = old_labels[to_update_dim]['dof'] label_name = old_labels[dim]['name'] - + # Handle slicing if isinstance(index, slice): - # Handle slicing to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name} + # Handle single integer index elif isinstance(index, int): - # Handle single integer index to_update_labels[dim] = {'dof': [old_dof[index]], 'name': label_name} + # Handle lists or tensors elif isinstance(index, (list, torch.Tensor)): - # Handle lists or tensors - indices = [index] if isinstance(index, (int, str)) else index + # Handle list of bools + if isinstance(index, torch.Tensor) and index.dtype == torch.bool: + index = index.nonzero().squeeze() to_update_labels[dim] = { - 'dof': [old_dof[i] for i in indices] if isinstance(old_dof, - list) else indices, + 'dof': [old_dof[i] for i in index] if isinstance(old_dof, + list) else index, 'name': label_name } else: diff --git a/tests/test_label_tensor/test_label_tensor.py b/tests/test_label_tensor/test_label_tensor.py index 41288e6..2c5d15e 100644 --- a/tests/test_label_tensor/test_label_tensor.py +++ b/tests/test_label_tensor/test_label_tensor.py @@ -281,3 +281,17 @@ def test_sorting(): assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all() assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all() assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all() + + +@pytest.mark.parametrize("labels", + [[f's{i}' for i in range(10)], + {0: {'dof': ['a', 'b', 'c']}, + 1: {'dof': [f's{i}' for i in range(10)]}}]) +def test_cat_bool(labels): + out = torch.randn((3, 10)) + out = LabelTensor(out, labels) + selected = out[torch.tensor([True, True, False])] + assert selected.shape == (2, 10) + assert selected.stored_labels[1]['dof'] == [f's{i}' for i in range(10)] + if isinstance(labels, dict): + assert selected.stored_labels[0]['dof'] == ['a', 'b']