Doc LabelTensor

This commit is contained in:
FilippoOlivo
2025-03-11 14:47:34 +01:00
parent 7b00b80ecb
commit c99157d98a

View File

@@ -6,10 +6,23 @@ from torch import Tensor
class LabelTensor(torch.Tensor): class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column.""" """
Extension of the :class:`torch.Tensor` class that includes labels for
each dimension.
"""
@staticmethod @staticmethod
def __new__(cls, x, labels, *args, **kwargs): def __new__(cls, x, labels, *args, **kwargs):
"""
Create a new instance of the :class:`LabelTensor` class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list(str) | dict
:return: The instance of the :class:`LabelTensor` class.
:rtype: LabelTensor
"""
if isinstance(x, LabelTensor): if isinstance(x, LabelTensor):
return x return x
@@ -18,47 +31,47 @@ class LabelTensor(torch.Tensor):
@property @property
def tensor(self): def tensor(self):
""" """
Give the tensor part of the LabelTensor. Give the tensor part of the :class:`LabelTensor` object.
:return: tensor part of the LabelTensor :return: tensor part of the :class:`LabelTensor`.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
return self.as_subclass(Tensor) return self.as_subclass(Tensor)
def __init__(self, x, labels): def __init__(self, x, labels):
""" """
Construct a `LabelTensor` by passing a dict of the labels Construct a :class:`LabelTensor` by passing a dict of the labels and a
:class:`torch.Tensor`. Internally, the initialization method will store
check the compatibility of the labels with the tensor shape.
:Example: :Example:
>>> from pina import LabelTensor >>> from pina import LabelTensor
>>> tensor = LabelTensor( >>> tensor = LabelTensor(
>>> torch.rand((2000, 3)), >>> torch.rand((2000, 3)),
{1: {"name": "space"['a', 'b', 'c']) ... {1: {"name": "space", "dof": ['a', 'b', 'c'])
>>> tensor = LabelTensor(
>>> torch.rand((2000, 3)),
... ["a", "b", "c"])
""" """
# Avoid unused argument warning. x is not used in the constructor
# of the parent class.
# pylint: disable=unused-argument
super().__init__() super().__init__()
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
else: else:
self._labels = {} self._labels = {}
@property
def labels(self):
"""Property decorator for labels
:return: labels of self
:rtype: list
"""
if self.ndim - 1 in self._labels:
return self._labels[self.ndim - 1]["dof"]
return None
@property @property
def full_labels(self): def full_labels(self):
"""Property decorator for labels """
Gives the full labels of the tensor, even for the dimensions that are
not labeled.
:return: labels of self :return: The full labels of the tensor
:rtype: list :rtype: dict
""" """
to_return_dict = {} to_return_dict = {}
shape_tensor = self.shape shape_tensor = self.shape
@@ -71,21 +84,40 @@ class LabelTensor(torch.Tensor):
@property @property
def stored_labels(self): def stored_labels(self):
"""Property decorator for labels """
Gives the labels stored inside the instance.
:return: labels of self :return: The labels stored inside the instance.
:rtype: list :rtype: dict
""" """
return self._labels return self._labels
@property
def labels(self):
"""
Give the labels of the last dimension of the instance.
:return: labels of last dimension
:rtype: list
"""
if self.ndim - 1 in self._labels:
return self._labels[self.ndim - 1]["dof"]
return None
@labels.setter @labels.setter
def labels(self, labels): def labels(self, labels):
""" " """
Set properly the parameter _labels Set the parameter ``_labels`` by checking the type of the input labels
and handling it accordingly. The following types are accepted:
- **list**: The list of labels is assigned to the last dimension.
- **dict**: The dictionary of labels is assigned to the tensor.
- **str**: The string is assigned to the last dimension.
:param labels: Labels to assign to the class variable _labels. :param labels: Labels to assign to the class variable _labels.
:type: labels: str | list(str) | dict :type labels: str | list(str) | dict
""" """
if not hasattr(self, "_labels"): if not hasattr(self, "_labels"):
self._labels = {} self._labels = {}
if isinstance(labels, dict): if isinstance(labels, dict):
@@ -100,14 +132,14 @@ class LabelTensor(torch.Tensor):
def _init_labels_from_dict(self, labels: dict): def _init_labels_from_dict(self, labels: dict):
""" """
Update the internal label representation according to the values Store the internal label representation according to the values
passed as input. passed as input.
:param labels: The label(s) to update. :param dict labels: The label(s) to update.
:type labels: dict
:raises ValueError: If the dof list contains duplicates or the number of :raises ValueError: If the dof list contains duplicates or the number of
dof does not match the tensor shape. dof does not match the tensor shape.
""" """
tensor_shape = self.shape tensor_shape = self.shape
def validate_dof(dof_list, dim_size: int): def validate_dof(dof_list, dim_size: int):
@@ -151,12 +183,13 @@ class LabelTensor(torch.Tensor):
def _init_labels_from_list(self, labels): def _init_labels_from_list(self, labels):
""" """
Given a list of dof, this method update the internal label Given a ``list`` of dof, this method update the internal label
representation representation by assigning the dof to the last dimension.
:param labels: The label(s) to update. :param labels: The label(s) to update.
:type labels: list :type labels: list
""" """
# Create a dict with labels # Create a dict with labels
last_dim_labels = { last_dim_labels = {
self.ndim - 1: {"dof": labels, "name": self.ndim - 1} self.ndim - 1: {"dof": labels, "name": self.ndim - 1}
@@ -165,12 +198,19 @@ class LabelTensor(torch.Tensor):
def extract(self, labels_to_extract): def extract(self, labels_to_extract):
""" """
Extract the subset of the original tensor by returning all the columns Extract the subset of the original tensor by returning all the positions
corresponding to the passed ``label_to_extract``. corresponding to the passed ``label_to_extract``.
:param labels_to_extract: The label(s) to extract. :param labels_to_extract: The label(s) to extract. If a single label or
:type labels_to_extract: str | list(str) | tuple(str) a list of labels is passed, the last dimension is considered.
:raises TypeError: Labels are not ``str``. If a dictionary is passed, the keys are the dimension names and the
values are the labels to extract.
:type labels_to_extract: str | list(str) | tuple(str) | dict
:return: The extracted tensor with the updated labels.
:rtype: LabelTensor
:raises TypeError: Labels are not ``str``, ``list(str)`` or ``dict``
properly setted.
:raises ValueError: Label to extract is not in the labels ``list``. :raises ValueError: Label to extract is not in the labels ``list``.
""" """
@@ -231,8 +271,12 @@ class LabelTensor(torch.Tensor):
def __str__(self): def __str__(self):
""" """
returns a string with the representation of the class The string representation of the :class:`LabelTensor`.
:return: String representation of the :class:`LabelTensor` instance.
:rtype: str
""" """
s = "" s = ""
for key, value in self._labels.items(): for key, value in self._labels.items():
s += f"{key}: {value}\n" s += f"{key}: {value}\n"
@@ -243,18 +287,20 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def cat(tensors, dim=0): def cat(tensors, dim=0):
""" """
Stack a list of tensors. For example, given a tensor `a` of shape Concatenate a list of tensors along a specified dimension. For more
`(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)` details, see :meth:`torch.cat`.
the resulting tensor is of shape `(n+n',m,dof)`
:param list(LabelTensor) tensors: :class:`LabelTensor` instances to
concatenate
:param int dim: dimensions on which you want to perform the operation
(default is 0)
:return: A new :class:`LabelTensor' instance obtained by concatenating
the input instances, with the updated labels.
:param tensors: tensors to concatenate
:type tensors: list of LabelTensor
:param dim: dimensions on which you want to perform the operation
(default is 0)
:type dim: int
:rtype: LabelTensor :rtype: LabelTensor
:raises ValueError: either number dof or dimensions names differ :raises ValueError: either number dof or dimensions names differ
""" """
if not tensors: if not tensors:
return [] # Handle empty list return [] # Handle empty list
if len(tensors) == 1: if len(tensors) == 1:
@@ -295,15 +341,16 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def stack(tensors): def stack(tensors):
""" """
Stacks a list of tensors along a new dimension. Stacks a list of tensors along a new dimension. For more details, see
:meth:`torch.stack`.
:param tensors: A list of tensors to stack. All tensors must have the :param list(LabelTensor) tensors: A list of tensors to stack.
same shape. All tensors must have the same shape.
:type tensors: list of LabelTensor :return: A new :class:`LabelTensor` instance obtained by stacking the
:return: A new tensor obtained by stacking the input tensors, input tensors, with the updated labels.
with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# Perform stacking in torch # Perform stacking in torch
new_tensor = torch.stack(tensors) new_tensor = torch.stack(tensors)
@@ -315,17 +362,17 @@ class LabelTensor(torch.Tensor):
def requires_grad_(self, mode=True): def requires_grad_(self, mode=True):
""" """
Override the requires_grad_ method to update the labels in the new Override the requires_grad_ method to handle the labels in the new
tensor. tensor. For more details, see :meth:`torch.Tensor.requires_grad_`.
:param mode: A boolean value indicating whether the tensor should track :param bool mode: A boolean value indicating whether the tensor should
gradients.If `True`, the tensor will track gradients; if ` track gradients.If `True`, the tensor will track gradients;
False`, it will not. if `False`, it will not.
:type mode: bool, optional (default is `True`) :return: The :class:`LabelTensor` itself with the updated
:return: The tensor itself with the updated `requires_grad` state and `requires_grad` state and retained labels.
retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
lt = super().requires_grad_(mode) lt = super().requires_grad_(mode)
lt._labels = self._labels lt._labels = self._labels
return lt return lt
@@ -333,30 +380,39 @@ class LabelTensor(torch.Tensor):
@property @property
def dtype(self): def dtype(self):
""" """
Give the dtype of the tensor. Give the ``dtype`` of the tensor. For more details, see
:meth:`torch.dtype`.
:return: dtype of the tensor :return: dtype of the tensor
:rtype: torch.dtype :rtype: torch.dtype
""" """
return super().dtype return super().dtype
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
""" """
Performs Tensor dtype and/or device conversion. For more details, see Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`. :meth:`torch.Tensor.to`.
:return: A new :class:`LabelTensor` instance with the updated dtype
and/or device and retained labels.
:rtype: LabelTensor
""" """
lt = super().to(*args, **kwargs) lt = super().to(*args, **kwargs)
lt._labels = self._labels lt._labels = self._labels
return lt return lt
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs):
""" """
Clone the LabelTensor. For more details, see Clone the :class:`LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`. :meth:`torch.Tensor.clone`.
:return: A copy of the tensor. :return: A new :class:`LabelTensor` instance with the same data and
labels but allocated in a different memory location.
:rtype: LabelTensor :rtype: LabelTensor
""" """
out = LabelTensor( out = LabelTensor(
super().clone(*args, **kwargs), deepcopy(self._labels) super().clone(*args, **kwargs), deepcopy(self._labels)
) )
@@ -366,21 +422,23 @@ class LabelTensor(torch.Tensor):
""" """
Appends a given tensor to the current tensor along the last dimension. Appends a given tensor to the current tensor along the last dimension.
This method allows for two types of appending operations: This method supports two types of appending operations:
1. **Standard append** ("std"): Concatenates the tensors along the 1. **Standard append** ("std"): Concatenates the input tensor with the
last dimension. current tensor along the last dimension.
2. **Cross append** ("cross"): Repeats the current tensor and the new 2. **Cross append** ("cross"): Creates a cross-product of the current
tensor in a cross-product manner, then concatenates them. tensor and the input tensor by repeating them in a cross-product
fashion, then concatenates the result along the last dimension.
:param LabelTensor tensor: The tensor to append. :param tensor: The tensor to append to the current tensor.
:type tensor: LabelTensor
:param mode: The append mode to use. Defaults to "std". :param mode: The append mode to use. Defaults to "std".
:type mode: str, optional :type mode: str, optional
:return: The new tensor obtained by appending the input tensor :return: A new `LabelTensor` obtained by appending the input tensor.
(either 'std' or 'cross').
:rtype: LabelTensor :rtype: LabelTensor
:raises ValueError: If the mode is not "std" or "cross". :raises ValueError: If the mode is not "std" or "cross".
""" """
if mode == "std": if mode == "std":
# Call cat on last dimension # Call cat on last dimension
new_label_tensor = LabelTensor.cat( new_label_tensor = LabelTensor.cat(
@@ -406,14 +464,15 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def vstack(label_tensors): def vstack(label_tensors):
""" """
Stack tensors vertically. For more details, see Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: the tensors to stack. They need :param list(LabelTensor) label_tensors: The :class:`LabelTensor`
to have equal labels. instances to stack. They need to have equal labels.
:return: the stacked tensor :return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors vertically.
:rtype: LabelTensor :rtype: LabelTensor
""" """
return LabelTensor.cat(label_tensors, dim=0) return LabelTensor.cat(label_tensors, dim=0)
# This method is used to update labels # This method is used to update labels
@@ -421,13 +480,17 @@ class LabelTensor(torch.Tensor):
self, old_labels, to_update_labels, index, dim, to_update_dim self, old_labels, to_update_labels, index, dim, to_update_dim
): ):
""" """
Update the labels of the tensor by selecting only the labels Update the labels of the tensor based on the index (or list of indices).
:param old_labels: labels from which retrieve data
:param to_update_labels: labels to update :param dict old_labels: Labels from which retrieve data.
:param index: index of dof to retain :param dict to_update_labels: Labels to update.
:param dim: label index :param index: Index of dof to retain.
:return: :type index: int | slice | list | torch.Tensor]
:param int dim: The dimension to update.
:raises: ValueError: If the index type is not supported.
""" """
old_dof = old_labels[to_update_dim]["dof"] old_dof = old_labels[to_update_dim]["dof"]
label_name = old_labels[dim]["name"] label_name = old_labels[dim]["name"]
# Handle slicing # Handle slicing
@@ -460,16 +523,22 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index): def __getitem__(self, index):
""" " """ "
Override the __getitem__ method to handle the labels of the tensor. Override the __getitem__ method to handle the labels of the
Perform the __getitem__ operation on the tensor and update the labels. :class:`LabelTensor` instance. It first performs __getitem__ operation
on the :class:`torch.Tensor` part of the instance, then updates the
labels based on the index.
:param index: The index used to access the item :param index: The index used to access the item
:type index: Union[int, str, tuple, list] :type index: int | str | tuple | list | torch.Tensor
:return: A tensor-like object with updated labels. :return: A new :class:`LabelTensor` instance obtained __getitem__
operation on :class:`torch.Tensor` part of the instance, with the
updated labels.
:rtype: LabelTensor :rtype: LabelTensor
:raises KeyError: If an invalid label index is provided. :raises KeyError: If an invalid label index is provided.
:raises IndexError: If an invalid index is accessed in the tensor. :raises IndexError: If an invalid index is accessed in the tensor.
""" """
# Handle string index # Handle string index
if isinstance(index, str) or ( if isinstance(index, str) or (
isinstance(index, (tuple, list)) isinstance(index, (tuple, list))
@@ -516,12 +585,11 @@ class LabelTensor(torch.Tensor):
def sort_labels(self, dim=None): def sort_labels(self, dim=None):
""" """
Sorts the labels along a specified dimension and returns a new tensor Sort the labels along the specified dimension and apply the same sorting
with sorted labels. to the :class:`torch.Tensor` part of the instance.
:param dim: The dimension along which to sort the labels. If `None`, :param int dim: The dimension along which to sort the labels.
the last dimension (`ndim - 1`) is used. If ``None``, the last dimension (``ndim - 1``) is used.
:type dim: int, optional
:return: A new tensor with sorted labels along the specified dimension. :return: A new tensor with sorted labels along the specified dimension.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -543,13 +611,15 @@ class LabelTensor(torch.Tensor):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
""" """
Creates a deep copy of the object. Creates a deep copy of the object. For more details, see
:meth:`copy.deepcopy`.
:param memo: LabelTensor object to be copied. :param memo: LabelTensor object to be copied.
:type memo: LabelTensor :type memo: LabelTensor
:return: A deep copy of the original LabelTensor object. :return: A deep copy of the original LabelTensor object.
:rtype: LabelTensor :rtype: LabelTensor
""" """
cls = self.__class__ cls = self.__class__
result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels)) result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
return result return result
@@ -557,7 +627,7 @@ class LabelTensor(torch.Tensor):
def permute(self, *dims): def permute(self, *dims):
""" """
Permutes the dimensions of the tensor and the associated labels Permutes the dimensions of the tensor and the associated labels
accordingly. accordingly. For more details, see :meth:`torch.Tensor.permute`.
:param dims: The dimensions to permute the tensor to. :param dims: The dimensions to permute the tensor to.
:type dims: tuple, list :type dims: tuple, list
@@ -579,11 +649,12 @@ class LabelTensor(torch.Tensor):
def detach(self): def detach(self):
""" """
Detaches the tensor from the computation graph and retains the stored Detaches the tensor from the computation graph and retains the stored
labels. labels. For more details, see :meth:`torch.Tensor.detach`.
:return: A new tensor detached from the computation graph. :return: A new tensor detached from the computation graph.
:rtype: LabelTensor :rtype: LabelTensor
""" """
lt = super().detach() lt = super().detach()
# Copy the labels to the new tensor only if present # Copy the labels to the new tensor only if present
@@ -594,14 +665,15 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def summation(tensors): def summation(tensors):
""" """
Computes the summation of a list of tensors. Computes the summation of a list of :class:`LabelTensor` instances.
:param tensors: A list of tensors to sum. All tensors must have the same :param list(LabelTensor) tensors: A list of tensors to sum. All tensors
shape and labels. must have the same shape and labels.
:type tensors: list of LabelTensor
:return: A new `LabelTensor` containing the element-wise sum of the :return: A new `LabelTensor` containing the element-wise sum of the
input tensors. input tensors.
:rtype: LabelTensor :rtype: LabelTensor
:raises ValueError: If the input `tensors` list is empty. :raises ValueError: If the input `tensors` list is empty.
:raises RuntimeError: If the tensors have different shapes and/or :raises RuntimeError: If the tensors have different shapes and/or
mismatched labels. mismatched labels.
@@ -637,12 +709,14 @@ class LabelTensor(torch.Tensor):
def reshape(self, *shape): def reshape(self, *shape):
""" """
Override the reshape method to update the labels of the tensor. Override the reshape method to update the labels of the tensor.
For more details, see :meth:`torch.Tensor.reshape`.
:param shape: The new shape of the tensor. :param tuple shape: The new shape of the tensor.
:type shape: tuple :return: A new :class:`LabelTensor` instance with the updated shape and
:return: A tensor-like object with updated labels. labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# As for now the reshape method is used only in the context of the # As for now the reshape method is used only in the context of the
# dataset, the labels are not # dataset, the labels are not
tensor = super().reshape(*shape) tensor = super().reshape(*shape)