Fix __getitem__ in LabelTensor (#546)
* Fix LabelTensor * Cleaning label_tensor.py --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
@@ -505,50 +505,40 @@ class LabelTensor(torch.Tensor):
|
|||||||
return LabelTensor.cat(tensors, dim=0)
|
return LabelTensor.cat(tensors, dim=0)
|
||||||
|
|
||||||
# This method is used to update labels
|
# This method is used to update labels
|
||||||
def _update_single_label(
|
def _update_single_label(self, index, dim):
|
||||||
self, old_labels, to_update_labels, index, dim, to_update_dim
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Update the labels of the tensor based on the index (or list of indices).
|
Update the labels of the tensor based on the index (or list of indices).
|
||||||
|
|
||||||
:param dict old_labels: Labels from which retrieve data.
|
|
||||||
:param dict to_update_labels: Labels to update.
|
|
||||||
:param index: Index of dof to retain.
|
:param index: Index of dof to retain.
|
||||||
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
|
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
|
||||||
:param int dim: The dimension to update.
|
:param int dim: Dimension of the indexes in the original tensor.
|
||||||
|
:return: The updated labels for the specified dimension.
|
||||||
|
:rtype: list[int]
|
||||||
:raises: ValueError: If the index type is not supported.
|
:raises: ValueError: If the index type is not supported.
|
||||||
"""
|
"""
|
||||||
|
old_dof = self._labels[dim]["dof"]
|
||||||
old_dof = old_labels[to_update_dim]["dof"]
|
|
||||||
label_name = old_labels[dim]["name"]
|
|
||||||
# Handle slicing
|
# Handle slicing
|
||||||
if isinstance(index, slice):
|
if isinstance(index, slice):
|
||||||
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
|
new_dof = old_dof[index]
|
||||||
# Handle single integer index
|
# Handle single integer index
|
||||||
elif isinstance(index, int):
|
elif isinstance(index, int):
|
||||||
to_update_labels[dim] = {
|
new_dof = [old_dof[index]]
|
||||||
"dof": [old_dof[index]],
|
|
||||||
"name": label_name,
|
|
||||||
}
|
|
||||||
# Handle lists or tensors
|
# Handle lists or tensors
|
||||||
elif isinstance(index, (list, torch.Tensor)):
|
elif isinstance(index, (list, torch.Tensor)):
|
||||||
# Handle list of bools
|
# Handle list of bools
|
||||||
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
||||||
index = index.nonzero().squeeze()
|
index = index.nonzero().squeeze()
|
||||||
to_update_labels[dim] = {
|
new_dof = (
|
||||||
"dof": (
|
|
||||||
[old_dof[i] for i in index]
|
[old_dof[i] for i in index]
|
||||||
if isinstance(old_dof, list)
|
if isinstance(old_dof, list)
|
||||||
else index
|
else index
|
||||||
),
|
)
|
||||||
"name": label_name,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unsupported index type: {type(index)}. Expected slice, int, "
|
f"Unsupported index type: {type(index)}. Expected slice, int, "
|
||||||
f"list, or torch.Tensor."
|
f"list, or torch.Tensor."
|
||||||
)
|
)
|
||||||
|
return new_dof
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
""" "
|
""" "
|
||||||
@@ -589,14 +579,20 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
# Update labels based on the index
|
# Update labels based on the index
|
||||||
offset = 0
|
offset = 0
|
||||||
|
removed = 0
|
||||||
for dim, idx in enumerate(index):
|
for dim, idx in enumerate(index):
|
||||||
if dim in self.stored_labels:
|
if dim in original_labels:
|
||||||
if isinstance(idx, int):
|
if isinstance(idx, int):
|
||||||
selected_tensor = selected_tensor.unsqueeze(dim)
|
# Compute the working dimension considering the removed
|
||||||
|
# dimensions due to int index on a non labled dimension
|
||||||
|
dim_ = dim - removed
|
||||||
|
selected_tensor = selected_tensor.unsqueeze(dim_)
|
||||||
if idx != slice(None):
|
if idx != slice(None):
|
||||||
self._update_single_label(
|
# Update the labels for the selected dimension
|
||||||
original_labels, updated_labels, idx, dim, offset
|
updated_labels[offset] = {
|
||||||
)
|
"dof": self._update_single_label(idx, dim),
|
||||||
|
"name": original_labels[dim]["name"],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# Adjust label keys if dimension is reduced (case of integer
|
# Adjust label keys if dimension is reduced (case of integer
|
||||||
# index on a non-labeled dimension)
|
# index on a non-labeled dimension)
|
||||||
@@ -605,6 +601,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
key - 1 if key > dim else key: value
|
key - 1 if key > dim else key: value
|
||||||
for key, value in updated_labels.items()
|
for key, value in updated_labels.items()
|
||||||
}
|
}
|
||||||
|
removed += 1
|
||||||
continue
|
continue
|
||||||
offset += 1
|
offset += 1
|
||||||
|
|
||||||
|
|||||||
@@ -278,3 +278,63 @@ def test_cat_bool(labels):
|
|||||||
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
|
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
|
||||||
if isinstance(labels, dict):
|
if isinstance(labels, dict):
|
||||||
assert selected.stored_labels[0]["dof"] == ["a", "b"]
|
assert selected.stored_labels[0]["dof"] == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem_int():
|
||||||
|
data = torch.rand(20, 3)
|
||||||
|
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
|
||||||
|
lt = LabelTensor(data, labels)
|
||||||
|
new = lt[0, 0]
|
||||||
|
assert new.ndim == 1
|
||||||
|
assert new.shape[0] == 1
|
||||||
|
assert torch.all(torch.isclose(data[0, 0], new))
|
||||||
|
|
||||||
|
data = torch.rand(20, 3, 2)
|
||||||
|
labels = {
|
||||||
|
1: {"name": 1, "dof": ["x", "y", "z"]},
|
||||||
|
2: {"name": 2, "dof": ["a", "b"]},
|
||||||
|
}
|
||||||
|
lt = LabelTensor(data, labels)
|
||||||
|
new = lt[0, 0, 0]
|
||||||
|
assert new.ndim == 2
|
||||||
|
assert new.shape[0] == 1
|
||||||
|
assert new.shape[1] == 1
|
||||||
|
assert torch.all(torch.isclose(data[0, 0, 0], new))
|
||||||
|
assert new.stored_labels[0]["dof"] == ["x"]
|
||||||
|
assert new.stored_labels[1]["dof"] == ["a"]
|
||||||
|
|
||||||
|
new = lt[0, 0, :]
|
||||||
|
assert new.ndim == 2
|
||||||
|
assert new.shape[0] == 1
|
||||||
|
assert new.shape[1] == 2
|
||||||
|
assert torch.all(torch.isclose(data[0, 0, :], new))
|
||||||
|
assert new.stored_labels[0]["dof"] == ["x"]
|
||||||
|
assert new.stored_labels[1]["dof"] == ["a", "b"]
|
||||||
|
|
||||||
|
new = lt[0, :, 1]
|
||||||
|
assert new.ndim == 2
|
||||||
|
assert new.shape[0] == 3
|
||||||
|
assert new.shape[1] == 1
|
||||||
|
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
|
||||||
|
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
|
||||||
|
assert new.stored_labels[1]["dof"] == ["b"]
|
||||||
|
|
||||||
|
labels.pop(2)
|
||||||
|
lt = LabelTensor(data, labels)
|
||||||
|
new = lt[0, 0, 0]
|
||||||
|
assert new.ndim == 1
|
||||||
|
assert new.shape[0] == 1
|
||||||
|
assert new.stored_labels[0]["dof"] == ["x"]
|
||||||
|
|
||||||
|
new = lt[:, 0, 0]
|
||||||
|
assert new.ndim == 2
|
||||||
|
assert new.shape[0] == 20
|
||||||
|
assert new.shape[1] == 1
|
||||||
|
assert new.stored_labels[1]["dof"] == ["x"]
|
||||||
|
|
||||||
|
new = lt[:, 0, :]
|
||||||
|
assert new.ndim == 3
|
||||||
|
assert new.shape[0] == 20
|
||||||
|
assert new.shape[1] == 1
|
||||||
|
assert new.shape[2] == 2
|
||||||
|
assert new.stored_labels[1]["dof"] == ["x"]
|
||||||
|
|||||||
Reference in New Issue
Block a user