Fix __getitem__ in LabelTensor (#546)

* Fix LabelTensor
* Cleaning label_tensor.py

---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Filippo Olivo
2025-04-14 11:41:59 +02:00
committed by FilippoOlivo
parent e250e3f5f7
commit a1f98764d3
2 changed files with 85 additions and 28 deletions

View File

@@ -505,50 +505,40 @@ class LabelTensor(torch.Tensor):
return LabelTensor.cat(tensors, dim=0) return LabelTensor.cat(tensors, dim=0)
# This method is used to update labels # This method is used to update labels
def _update_single_label( def _update_single_label(self, index, dim):
self, old_labels, to_update_labels, index, dim, to_update_dim
):
""" """
Update the labels of the tensor based on the index (or list of indices). Update the labels of the tensor based on the index (or list of indices).
:param dict old_labels: Labels from which retrieve data.
:param dict to_update_labels: Labels to update.
:param index: Index of dof to retain. :param index: Index of dof to retain.
:type index: int | slice | list[int] | tuple[int] | torch.Tensor :type index: int | slice | list[int] | tuple[int] | torch.Tensor
:param int dim: The dimension to update. :param int dim: Dimension of the indexes in the original tensor.
:return: The updated labels for the specified dimension.
:rtype: list[int]
:raises: ValueError: If the index type is not supported. :raises: ValueError: If the index type is not supported.
""" """
old_dof = self._labels[dim]["dof"]
old_dof = old_labels[to_update_dim]["dof"]
label_name = old_labels[dim]["name"]
# Handle slicing # Handle slicing
if isinstance(index, slice): if isinstance(index, slice):
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name} new_dof = old_dof[index]
# Handle single integer index # Handle single integer index
elif isinstance(index, int): elif isinstance(index, int):
to_update_labels[dim] = { new_dof = [old_dof[index]]
"dof": [old_dof[index]],
"name": label_name,
}
# Handle lists or tensors # Handle lists or tensors
elif isinstance(index, (list, torch.Tensor)): elif isinstance(index, (list, torch.Tensor)):
# Handle list of bools # Handle list of bools
if isinstance(index, torch.Tensor) and index.dtype == torch.bool: if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze() index = index.nonzero().squeeze()
to_update_labels[dim] = { new_dof = (
"dof": (
[old_dof[i] for i in index] [old_dof[i] for i in index]
if isinstance(old_dof, list) if isinstance(old_dof, list)
else index else index
), )
"name": label_name,
}
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported index type: {type(index)}. Expected slice, int, " f"Unsupported index type: {type(index)}. Expected slice, int, "
f"list, or torch.Tensor." f"list, or torch.Tensor."
) )
return new_dof
def __getitem__(self, index): def __getitem__(self, index):
""" " """ "
@@ -589,14 +579,20 @@ class LabelTensor(torch.Tensor):
# Update labels based on the index # Update labels based on the index
offset = 0 offset = 0
removed = 0
for dim, idx in enumerate(index): for dim, idx in enumerate(index):
if dim in self.stored_labels: if dim in original_labels:
if isinstance(idx, int): if isinstance(idx, int):
selected_tensor = selected_tensor.unsqueeze(dim) # Compute the working dimension considering the removed
# dimensions due to int index on a non labled dimension
dim_ = dim - removed
selected_tensor = selected_tensor.unsqueeze(dim_)
if idx != slice(None): if idx != slice(None):
self._update_single_label( # Update the labels for the selected dimension
original_labels, updated_labels, idx, dim, offset updated_labels[offset] = {
) "dof": self._update_single_label(idx, dim),
"name": original_labels[dim]["name"],
}
else: else:
# Adjust label keys if dimension is reduced (case of integer # Adjust label keys if dimension is reduced (case of integer
# index on a non-labeled dimension) # index on a non-labeled dimension)
@@ -605,6 +601,7 @@ class LabelTensor(torch.Tensor):
key - 1 if key > dim else key: value key - 1 if key > dim else key: value
for key, value in updated_labels.items() for key, value in updated_labels.items()
} }
removed += 1
continue continue
offset += 1 offset += 1

View File

@@ -278,3 +278,63 @@ def test_cat_bool(labels):
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)] assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
if isinstance(labels, dict): if isinstance(labels, dict):
assert selected.stored_labels[0]["dof"] == ["a", "b"] assert selected.stored_labels[0]["dof"] == ["a", "b"]
def test_getitem_int():
data = torch.rand(20, 3)
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
lt = LabelTensor(data, labels)
new = lt[0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[0, 0], new))
data = torch.rand(20, 3, 2)
labels = {
1: {"name": 1, "dof": ["x", "y", "z"]},
2: {"name": 2, "dof": ["a", "b"]},
}
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, 0, 0], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a"]
new = lt[0, 0, :]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 2
assert torch.all(torch.isclose(data[0, 0, :], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a", "b"]
new = lt[0, :, 1]
assert new.ndim == 2
assert new.shape[0] == 3
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
assert new.stored_labels[1]["dof"] == ["b"]
labels.pop(2)
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert new.stored_labels[0]["dof"] == ["x"]
new = lt[:, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.stored_labels[1]["dof"] == ["x"]
new = lt[:, 0, :]
assert new.ndim == 3
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.shape[2] == 2
assert new.stored_labels[1]["dof"] == ["x"]