Fix __getitem__ in LabelTensor (#546)

* Fix LabelTensor
* Cleaning label_tensor.py

---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Filippo Olivo
2025-04-14 11:41:59 +02:00
committed by GitHub
parent 0a60ed4c9a
commit 6ed3ca04fe
2 changed files with 85 additions and 28 deletions

View File

@@ -505,50 +505,40 @@ class LabelTensor(torch.Tensor):
return LabelTensor.cat(tensors, dim=0)
# This method is used to update labels
def _update_single_label(
self, old_labels, to_update_labels, index, dim, to_update_dim
):
def _update_single_label(self, index, dim):
"""
Update the labels of the tensor based on the index (or list of indices).
:param dict old_labels: Labels from which retrieve data.
:param dict to_update_labels: Labels to update.
:param index: Index of dof to retain.
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
:param int dim: The dimension to update.
:param int dim: Dimension of the indexes in the original tensor.
:return: The updated labels for the specified dimension.
:rtype: list[int]
:raises: ValueError: If the index type is not supported.
"""
old_dof = old_labels[to_update_dim]["dof"]
label_name = old_labels[dim]["name"]
old_dof = self._labels[dim]["dof"]
# Handle slicing
if isinstance(index, slice):
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
new_dof = old_dof[index]
# Handle single integer index
elif isinstance(index, int):
to_update_labels[dim] = {
"dof": [old_dof[index]],
"name": label_name,
}
new_dof = [old_dof[index]]
# Handle lists or tensors
elif isinstance(index, (list, torch.Tensor)):
# Handle list of bools
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze()
to_update_labels[dim] = {
"dof": (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
else index
),
"name": label_name,
}
new_dof = (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
else index
)
else:
raise NotImplementedError(
f"Unsupported index type: {type(index)}. Expected slice, int, "
f"list, or torch.Tensor."
)
return new_dof
def __getitem__(self, index):
""" "
@@ -589,14 +579,20 @@ class LabelTensor(torch.Tensor):
# Update labels based on the index
offset = 0
removed = 0
for dim, idx in enumerate(index):
if dim in self.stored_labels:
if dim in original_labels:
if isinstance(idx, int):
selected_tensor = selected_tensor.unsqueeze(dim)
# Compute the working dimension considering the removed
# dimensions due to int index on a non labled dimension
dim_ = dim - removed
selected_tensor = selected_tensor.unsqueeze(dim_)
if idx != slice(None):
self._update_single_label(
original_labels, updated_labels, idx, dim, offset
)
# Update the labels for the selected dimension
updated_labels[offset] = {
"dof": self._update_single_label(idx, dim),
"name": original_labels[dim]["name"],
}
else:
# Adjust label keys if dimension is reduced (case of integer
# index on a non-labeled dimension)
@@ -605,6 +601,7 @@ class LabelTensor(torch.Tensor):
key - 1 if key > dim else key: value
for key, value in updated_labels.items()
}
removed += 1
continue
offset += 1

View File

@@ -278,3 +278,63 @@ def test_cat_bool(labels):
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
if isinstance(labels, dict):
assert selected.stored_labels[0]["dof"] == ["a", "b"]
def test_getitem_int():
data = torch.rand(20, 3)
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
lt = LabelTensor(data, labels)
new = lt[0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[0, 0], new))
data = torch.rand(20, 3, 2)
labels = {
1: {"name": 1, "dof": ["x", "y", "z"]},
2: {"name": 2, "dof": ["a", "b"]},
}
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, 0, 0], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a"]
new = lt[0, 0, :]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 2
assert torch.all(torch.isclose(data[0, 0, :], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a", "b"]
new = lt[0, :, 1]
assert new.ndim == 2
assert new.shape[0] == 3
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
assert new.stored_labels[1]["dof"] == ["b"]
labels.pop(2)
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert new.stored_labels[0]["dof"] == ["x"]
new = lt[:, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.stored_labels[1]["dof"] == ["x"]
new = lt[:, 0, :]
assert new.ndim == 3
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.shape[2] == 2
assert new.stored_labels[1]["dof"] == ["x"]