From 6ed3ca04fee3ae3673d53ea384437ce270f008da Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Mon, 14 Apr 2025 11:41:59 +0200 Subject: [PATCH] Fix `__getitem__` in LabelTensor (#546) * Fix LabelTensor * Cleaning label_tensor.py --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> --- pina/label_tensor.py | 53 ++++++++--------- tests/test_label_tensor/test_label_tensor.py | 60 ++++++++++++++++++++ 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 3ff1e79..098f514 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -505,50 +505,40 @@ class LabelTensor(torch.Tensor): return LabelTensor.cat(tensors, dim=0) # This method is used to update labels - def _update_single_label( - self, old_labels, to_update_labels, index, dim, to_update_dim - ): + def _update_single_label(self, index, dim): """ Update the labels of the tensor based on the index (or list of indices). - :param dict old_labels: Labels from which retrieve data. - :param dict to_update_labels: Labels to update. :param index: Index of dof to retain. :type index: int | slice | list[int] | tuple[int] | torch.Tensor - :param int dim: The dimension to update. - + :param int dim: Dimension of the indexes in the original tensor. + :return: The updated labels for the specified dimension. + :rtype: list[int] :raises: ValueError: If the index type is not supported. """ - - old_dof = old_labels[to_update_dim]["dof"] - label_name = old_labels[dim]["name"] + old_dof = self._labels[dim]["dof"] # Handle slicing if isinstance(index, slice): - to_update_labels[dim] = {"dof": old_dof[index], "name": label_name} + new_dof = old_dof[index] # Handle single integer index elif isinstance(index, int): - to_update_labels[dim] = { - "dof": [old_dof[index]], - "name": label_name, - } + new_dof = [old_dof[index]] # Handle lists or tensors elif isinstance(index, (list, torch.Tensor)): # Handle list of bools if isinstance(index, torch.Tensor) and index.dtype == torch.bool: index = index.nonzero().squeeze() - to_update_labels[dim] = { - "dof": ( - [old_dof[i] for i in index] - if isinstance(old_dof, list) - else index - ), - "name": label_name, - } + new_dof = ( + [old_dof[i] for i in index] + if isinstance(old_dof, list) + else index + ) else: raise NotImplementedError( f"Unsupported index type: {type(index)}. Expected slice, int, " f"list, or torch.Tensor." ) + return new_dof def __getitem__(self, index): """ " @@ -589,14 +579,20 @@ class LabelTensor(torch.Tensor): # Update labels based on the index offset = 0 + removed = 0 for dim, idx in enumerate(index): - if dim in self.stored_labels: + if dim in original_labels: if isinstance(idx, int): - selected_tensor = selected_tensor.unsqueeze(dim) + # Compute the working dimension considering the removed + # dimensions due to int index on a non labled dimension + dim_ = dim - removed + selected_tensor = selected_tensor.unsqueeze(dim_) if idx != slice(None): - self._update_single_label( - original_labels, updated_labels, idx, dim, offset - ) + # Update the labels for the selected dimension + updated_labels[offset] = { + "dof": self._update_single_label(idx, dim), + "name": original_labels[dim]["name"], + } else: # Adjust label keys if dimension is reduced (case of integer # index on a non-labeled dimension) @@ -605,6 +601,7 @@ class LabelTensor(torch.Tensor): key - 1 if key > dim else key: value for key, value in updated_labels.items() } + removed += 1 continue offset += 1 diff --git a/tests/test_label_tensor/test_label_tensor.py b/tests/test_label_tensor/test_label_tensor.py index 556957b..973864d 100644 --- a/tests/test_label_tensor/test_label_tensor.py +++ b/tests/test_label_tensor/test_label_tensor.py @@ -278,3 +278,63 @@ def test_cat_bool(labels): assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)] if isinstance(labels, dict): assert selected.stored_labels[0]["dof"] == ["a", "b"] + + +def test_getitem_int(): + data = torch.rand(20, 3) + labels = {1: {"name": 1, "dof": ["x", "y", "z"]}} + lt = LabelTensor(data, labels) + new = lt[0, 0] + assert new.ndim == 1 + assert new.shape[0] == 1 + assert torch.all(torch.isclose(data[0, 0], new)) + + data = torch.rand(20, 3, 2) + labels = { + 1: {"name": 1, "dof": ["x", "y", "z"]}, + 2: {"name": 2, "dof": ["a", "b"]}, + } + lt = LabelTensor(data, labels) + new = lt[0, 0, 0] + assert new.ndim == 2 + assert new.shape[0] == 1 + assert new.shape[1] == 1 + assert torch.all(torch.isclose(data[0, 0, 0], new)) + assert new.stored_labels[0]["dof"] == ["x"] + assert new.stored_labels[1]["dof"] == ["a"] + + new = lt[0, 0, :] + assert new.ndim == 2 + assert new.shape[0] == 1 + assert new.shape[1] == 2 + assert torch.all(torch.isclose(data[0, 0, :], new)) + assert new.stored_labels[0]["dof"] == ["x"] + assert new.stored_labels[1]["dof"] == ["a", "b"] + + new = lt[0, :, 1] + assert new.ndim == 2 + assert new.shape[0] == 3 + assert new.shape[1] == 1 + assert torch.all(torch.isclose(data[0, :, 1], new.squeeze())) + assert new.stored_labels[0]["dof"] == ["x", "y", "z"] + assert new.stored_labels[1]["dof"] == ["b"] + + labels.pop(2) + lt = LabelTensor(data, labels) + new = lt[0, 0, 0] + assert new.ndim == 1 + assert new.shape[0] == 1 + assert new.stored_labels[0]["dof"] == ["x"] + + new = lt[:, 0, 0] + assert new.ndim == 2 + assert new.shape[0] == 20 + assert new.shape[1] == 1 + assert new.stored_labels[1]["dof"] == ["x"] + + new = lt[:, 0, :] + assert new.ndim == 3 + assert new.shape[0] == 20 + assert new.shape[1] == 1 + assert new.shape[2] == 2 + assert new.stored_labels[1]["dof"] == ["x"]