Fix __getitem__ in LabelTensor (#546)
* Fix LabelTensor * Cleaning label_tensor.py --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
FilippoOlivo
parent
e250e3f5f7
commit
a1f98764d3
@@ -505,50 +505,40 @@ class LabelTensor(torch.Tensor):
|
||||
return LabelTensor.cat(tensors, dim=0)
|
||||
|
||||
# This method is used to update labels
|
||||
def _update_single_label(
|
||||
self, old_labels, to_update_labels, index, dim, to_update_dim
|
||||
):
|
||||
def _update_single_label(self, index, dim):
|
||||
"""
|
||||
Update the labels of the tensor based on the index (or list of indices).
|
||||
|
||||
:param dict old_labels: Labels from which retrieve data.
|
||||
:param dict to_update_labels: Labels to update.
|
||||
:param index: Index of dof to retain.
|
||||
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
|
||||
:param int dim: The dimension to update.
|
||||
|
||||
:param int dim: Dimension of the indexes in the original tensor.
|
||||
:return: The updated labels for the specified dimension.
|
||||
:rtype: list[int]
|
||||
:raises: ValueError: If the index type is not supported.
|
||||
"""
|
||||
|
||||
old_dof = old_labels[to_update_dim]["dof"]
|
||||
label_name = old_labels[dim]["name"]
|
||||
old_dof = self._labels[dim]["dof"]
|
||||
# Handle slicing
|
||||
if isinstance(index, slice):
|
||||
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
|
||||
new_dof = old_dof[index]
|
||||
# Handle single integer index
|
||||
elif isinstance(index, int):
|
||||
to_update_labels[dim] = {
|
||||
"dof": [old_dof[index]],
|
||||
"name": label_name,
|
||||
}
|
||||
new_dof = [old_dof[index]]
|
||||
# Handle lists or tensors
|
||||
elif isinstance(index, (list, torch.Tensor)):
|
||||
# Handle list of bools
|
||||
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
||||
index = index.nonzero().squeeze()
|
||||
to_update_labels[dim] = {
|
||||
"dof": (
|
||||
[old_dof[i] for i in index]
|
||||
if isinstance(old_dof, list)
|
||||
else index
|
||||
),
|
||||
"name": label_name,
|
||||
}
|
||||
new_dof = (
|
||||
[old_dof[i] for i in index]
|
||||
if isinstance(old_dof, list)
|
||||
else index
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported index type: {type(index)}. Expected slice, int, "
|
||||
f"list, or torch.Tensor."
|
||||
)
|
||||
return new_dof
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" "
|
||||
@@ -589,14 +579,20 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
# Update labels based on the index
|
||||
offset = 0
|
||||
removed = 0
|
||||
for dim, idx in enumerate(index):
|
||||
if dim in self.stored_labels:
|
||||
if dim in original_labels:
|
||||
if isinstance(idx, int):
|
||||
selected_tensor = selected_tensor.unsqueeze(dim)
|
||||
# Compute the working dimension considering the removed
|
||||
# dimensions due to int index on a non labled dimension
|
||||
dim_ = dim - removed
|
||||
selected_tensor = selected_tensor.unsqueeze(dim_)
|
||||
if idx != slice(None):
|
||||
self._update_single_label(
|
||||
original_labels, updated_labels, idx, dim, offset
|
||||
)
|
||||
# Update the labels for the selected dimension
|
||||
updated_labels[offset] = {
|
||||
"dof": self._update_single_label(idx, dim),
|
||||
"name": original_labels[dim]["name"],
|
||||
}
|
||||
else:
|
||||
# Adjust label keys if dimension is reduced (case of integer
|
||||
# index on a non-labeled dimension)
|
||||
@@ -605,6 +601,7 @@ class LabelTensor(torch.Tensor):
|
||||
key - 1 if key > dim else key: value
|
||||
for key, value in updated_labels.items()
|
||||
}
|
||||
removed += 1
|
||||
continue
|
||||
offset += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user