From c99157d98a453e61520d220a13207e8c622c9d2f Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 11 Mar 2025 14:47:34 +0100 Subject: [PATCH] Doc LabelTensor --- pina/label_tensor.py | 270 +++++++++++++++++++++++++++---------------- 1 file changed, 172 insertions(+), 98 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index cce141c..08132a7 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -6,10 +6,23 @@ from torch import Tensor class LabelTensor(torch.Tensor): - """Torch tensor with a label for any column.""" + """ + Extension of the :class:`torch.Tensor` class that includes labels for + each dimension. + """ @staticmethod def __new__(cls, x, labels, *args, **kwargs): + """ + Create a new instance of the :class:`LabelTensor` class. + + :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a + :class:`LabelTensor`. + :param labels: Labels to assign to the tensor. + :type labels: str | list(str) | dict + :return: The instance of the :class:`LabelTensor` class. + :rtype: LabelTensor + """ if isinstance(x, LabelTensor): return x @@ -18,47 +31,47 @@ class LabelTensor(torch.Tensor): @property def tensor(self): """ - Give the tensor part of the LabelTensor. + Give the tensor part of the :class:`LabelTensor` object. - :return: tensor part of the LabelTensor + :return: tensor part of the :class:`LabelTensor`. :rtype: torch.Tensor """ + return self.as_subclass(Tensor) def __init__(self, x, labels): """ - Construct a `LabelTensor` by passing a dict of the labels + Construct a :class:`LabelTensor` by passing a dict of the labels and a + :class:`torch.Tensor`. Internally, the initialization method will store + check the compatibility of the labels with the tensor shape. :Example: >>> from pina import LabelTensor >>> tensor = LabelTensor( >>> torch.rand((2000, 3)), - {1: {"name": "space"['a', 'b', 'c']) + ... {1: {"name": "space", "dof": ['a', 'b', 'c']) + >>> tensor = LabelTensor( + >>> torch.rand((2000, 3)), + ... ["a", "b", "c"]) """ + # Avoid unused argument warning. x is not used in the constructor + # of the parent class. + # pylint: disable=unused-argument super().__init__() if labels is not None: self.labels = labels else: self._labels = {} - @property - def labels(self): - """Property decorator for labels - - :return: labels of self - :rtype: list - """ - if self.ndim - 1 in self._labels: - return self._labels[self.ndim - 1]["dof"] - return None - @property def full_labels(self): - """Property decorator for labels + """ + Gives the full labels of the tensor, even for the dimensions that are + not labeled. - :return: labels of self - :rtype: list + :return: The full labels of the tensor + :rtype: dict """ to_return_dict = {} shape_tensor = self.shape @@ -71,21 +84,40 @@ class LabelTensor(torch.Tensor): @property def stored_labels(self): - """Property decorator for labels + """ + Gives the labels stored inside the instance. - :return: labels of self - :rtype: list + :return: The labels stored inside the instance. + :rtype: dict """ return self._labels + @property + def labels(self): + """ + Give the labels of the last dimension of the instance. + + :return: labels of last dimension + :rtype: list + """ + if self.ndim - 1 in self._labels: + return self._labels[self.ndim - 1]["dof"] + return None + @labels.setter def labels(self, labels): - """ " - Set properly the parameter _labels + """ + Set the parameter ``_labels`` by checking the type of the input labels + and handling it accordingly. The following types are accepted: + + - **list**: The list of labels is assigned to the last dimension. + - **dict**: The dictionary of labels is assigned to the tensor. + - **str**: The string is assigned to the last dimension. :param labels: Labels to assign to the class variable _labels. - :type: labels: str | list(str) | dict + :type labels: str | list(str) | dict """ + if not hasattr(self, "_labels"): self._labels = {} if isinstance(labels, dict): @@ -100,14 +132,14 @@ class LabelTensor(torch.Tensor): def _init_labels_from_dict(self, labels: dict): """ - Update the internal label representation according to the values + Store the internal label representation according to the values passed as input. - :param labels: The label(s) to update. - :type labels: dict + :param dict labels: The label(s) to update. :raises ValueError: If the dof list contains duplicates or the number of dof does not match the tensor shape. """ + tensor_shape = self.shape def validate_dof(dof_list, dim_size: int): @@ -151,12 +183,13 @@ class LabelTensor(torch.Tensor): def _init_labels_from_list(self, labels): """ - Given a list of dof, this method update the internal label - representation + Given a ``list`` of dof, this method update the internal label + representation by assigning the dof to the last dimension. :param labels: The label(s) to update. :type labels: list """ + # Create a dict with labels last_dim_labels = { self.ndim - 1: {"dof": labels, "name": self.ndim - 1} @@ -165,12 +198,19 @@ class LabelTensor(torch.Tensor): def extract(self, labels_to_extract): """ - Extract the subset of the original tensor by returning all the columns + Extract the subset of the original tensor by returning all the positions corresponding to the passed ``label_to_extract``. - :param labels_to_extract: The label(s) to extract. - :type labels_to_extract: str | list(str) | tuple(str) - :raises TypeError: Labels are not ``str``. + :param labels_to_extract: The label(s) to extract. If a single label or + a list of labels is passed, the last dimension is considered. + If a dictionary is passed, the keys are the dimension names and the + values are the labels to extract. + :type labels_to_extract: str | list(str) | tuple(str) | dict + :return: The extracted tensor with the updated labels. + :rtype: LabelTensor + + :raises TypeError: Labels are not ``str``, ``list(str)`` or ``dict`` + properly setted. :raises ValueError: Label to extract is not in the labels ``list``. """ @@ -231,8 +271,12 @@ class LabelTensor(torch.Tensor): def __str__(self): """ - returns a string with the representation of the class + The string representation of the :class:`LabelTensor`. + + :return: String representation of the :class:`LabelTensor` instance. + :rtype: str """ + s = "" for key, value in self._labels.items(): s += f"{key}: {value}\n" @@ -243,18 +287,20 @@ class LabelTensor(torch.Tensor): @staticmethod def cat(tensors, dim=0): """ - Stack a list of tensors. For example, given a tensor `a` of shape - `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)` - the resulting tensor is of shape `(n+n',m,dof)` + Concatenate a list of tensors along a specified dimension. For more + details, see :meth:`torch.cat`. + + :param list(LabelTensor) tensors: :class:`LabelTensor` instances to + concatenate + :param int dim: dimensions on which you want to perform the operation + (default is 0) + :return: A new :class:`LabelTensor' instance obtained by concatenating + the input instances, with the updated labels. - :param tensors: tensors to concatenate - :type tensors: list of LabelTensor - :param dim: dimensions on which you want to perform the operation - (default is 0) - :type dim: int :rtype: LabelTensor :raises ValueError: either number dof or dimensions names differ """ + if not tensors: return [] # Handle empty list if len(tensors) == 1: @@ -295,15 +341,16 @@ class LabelTensor(torch.Tensor): @staticmethod def stack(tensors): """ - Stacks a list of tensors along a new dimension. + Stacks a list of tensors along a new dimension. For more details, see + :meth:`torch.stack`. - :param tensors: A list of tensors to stack. All tensors must have the - same shape. - :type tensors: list of LabelTensor - :return: A new tensor obtained by stacking the input tensors, - with the updated labels. + :param list(LabelTensor) tensors: A list of tensors to stack. + All tensors must have the same shape. + :return: A new :class:`LabelTensor` instance obtained by stacking the + input tensors, with the updated labels. :rtype: LabelTensor """ + # Perform stacking in torch new_tensor = torch.stack(tensors) @@ -315,17 +362,17 @@ class LabelTensor(torch.Tensor): def requires_grad_(self, mode=True): """ - Override the requires_grad_ method to update the labels in the new - tensor. + Override the requires_grad_ method to handle the labels in the new + tensor. For more details, see :meth:`torch.Tensor.requires_grad_`. - :param mode: A boolean value indicating whether the tensor should track - gradients.If `True`, the tensor will track gradients; if ` - False`, it will not. - :type mode: bool, optional (default is `True`) - :return: The tensor itself with the updated `requires_grad` state and - retained labels. + :param bool mode: A boolean value indicating whether the tensor should + track gradients.If `True`, the tensor will track gradients; + if `False`, it will not. + :return: The :class:`LabelTensor` itself with the updated + `requires_grad` state and retained labels. :rtype: LabelTensor """ + lt = super().requires_grad_(mode) lt._labels = self._labels return lt @@ -333,30 +380,39 @@ class LabelTensor(torch.Tensor): @property def dtype(self): """ - Give the dtype of the tensor. + Give the ``dtype`` of the tensor. For more details, see + :meth:`torch.dtype`. :return: dtype of the tensor :rtype: torch.dtype """ + return super().dtype def to(self, *args, **kwargs): """ Performs Tensor dtype and/or device conversion. For more details, see :meth:`torch.Tensor.to`. + + :return: A new :class:`LabelTensor` instance with the updated dtype + and/or device and retained labels. + :rtype: LabelTensor """ + lt = super().to(*args, **kwargs) lt._labels = self._labels return lt def clone(self, *args, **kwargs): """ - Clone the LabelTensor. For more details, see + Clone the :class:`LabelTensor`. For more details, see :meth:`torch.Tensor.clone`. - :return: A copy of the tensor. + :return: A new :class:`LabelTensor` instance with the same data and + labels but allocated in a different memory location. :rtype: LabelTensor """ + out = LabelTensor( super().clone(*args, **kwargs), deepcopy(self._labels) ) @@ -366,21 +422,23 @@ class LabelTensor(torch.Tensor): """ Appends a given tensor to the current tensor along the last dimension. - This method allows for two types of appending operations: - 1. **Standard append** ("std"): Concatenates the tensors along the - last dimension. - 2. **Cross append** ("cross"): Repeats the current tensor and the new - tensor in a cross-product manner, then concatenates them. + This method supports two types of appending operations: + 1. **Standard append** ("std"): Concatenates the input tensor with the + current tensor along the last dimension. + 2. **Cross append** ("cross"): Creates a cross-product of the current + tensor and the input tensor by repeating them in a cross-product + fashion, then concatenates the result along the last dimension. - :param LabelTensor tensor: The tensor to append. + :param tensor: The tensor to append to the current tensor. + :type tensor: LabelTensor :param mode: The append mode to use. Defaults to "std". :type mode: str, optional - :return: The new tensor obtained by appending the input tensor - (either 'std' or 'cross'). + :return: A new `LabelTensor` obtained by appending the input tensor. :rtype: LabelTensor :raises ValueError: If the mode is not "std" or "cross". """ + if mode == "std": # Call cat on last dimension new_label_tensor = LabelTensor.cat( @@ -406,14 +464,15 @@ class LabelTensor(torch.Tensor): @staticmethod def vstack(label_tensors): """ - Stack tensors vertically. For more details, see - :meth:`torch.vstack`. + Stack tensors vertically. For more details, see :meth:`torch.vstack`. - :param list(LabelTensor) label_tensors: the tensors to stack. They need - to have equal labels. - :return: the stacked tensor + :param list(LabelTensor) label_tensors: The :class:`LabelTensor` + instances to stack. They need to have equal labels. + :return: A new :class:`LabelTensor` instance obtained by stacking the + input tensors vertically. :rtype: LabelTensor """ + return LabelTensor.cat(label_tensors, dim=0) # This method is used to update labels @@ -421,13 +480,17 @@ class LabelTensor(torch.Tensor): self, old_labels, to_update_labels, index, dim, to_update_dim ): """ - Update the labels of the tensor by selecting only the labels - :param old_labels: labels from which retrieve data - :param to_update_labels: labels to update - :param index: index of dof to retain - :param dim: label index - :return: + Update the labels of the tensor based on the index (or list of indices). + + :param dict old_labels: Labels from which retrieve data. + :param dict to_update_labels: Labels to update. + :param index: Index of dof to retain. + :type index: int | slice | list | torch.Tensor] + :param int dim: The dimension to update. + + :raises: ValueError: If the index type is not supported. """ + old_dof = old_labels[to_update_dim]["dof"] label_name = old_labels[dim]["name"] # Handle slicing @@ -460,16 +523,22 @@ class LabelTensor(torch.Tensor): def __getitem__(self, index): """ " - Override the __getitem__ method to handle the labels of the tensor. - Perform the __getitem__ operation on the tensor and update the labels. + Override the __getitem__ method to handle the labels of the + :class:`LabelTensor` instance. It first performs __getitem__ operation + on the :class:`torch.Tensor` part of the instance, then updates the + labels based on the index. :param index: The index used to access the item - :type index: Union[int, str, tuple, list] - :return: A tensor-like object with updated labels. + :type index: int | str | tuple | list | torch.Tensor + :return: A new :class:`LabelTensor` instance obtained __getitem__ + operation on :class:`torch.Tensor` part of the instance, with the + updated labels. :rtype: LabelTensor + :raises KeyError: If an invalid label index is provided. :raises IndexError: If an invalid index is accessed in the tensor. """ + # Handle string index if isinstance(index, str) or ( isinstance(index, (tuple, list)) @@ -516,12 +585,11 @@ class LabelTensor(torch.Tensor): def sort_labels(self, dim=None): """ - Sorts the labels along a specified dimension and returns a new tensor - with sorted labels. + Sort the labels along the specified dimension and apply the same sorting + to the :class:`torch.Tensor` part of the instance. - :param dim: The dimension along which to sort the labels. If `None`, - the last dimension (`ndim - 1`) is used. - :type dim: int, optional + :param int dim: The dimension along which to sort the labels. + If ``None``, the last dimension (``ndim - 1``) is used. :return: A new tensor with sorted labels along the specified dimension. :rtype: LabelTensor """ @@ -543,13 +611,15 @@ class LabelTensor(torch.Tensor): def __deepcopy__(self, memo): """ - Creates a deep copy of the object. + Creates a deep copy of the object. For more details, see + :meth:`copy.deepcopy`. :param memo: LabelTensor object to be copied. :type memo: LabelTensor :return: A deep copy of the original LabelTensor object. :rtype: LabelTensor """ + cls = self.__class__ result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels)) return result @@ -557,7 +627,7 @@ class LabelTensor(torch.Tensor): def permute(self, *dims): """ Permutes the dimensions of the tensor and the associated labels - accordingly. + accordingly. For more details, see :meth:`torch.Tensor.permute`. :param dims: The dimensions to permute the tensor to. :type dims: tuple, list @@ -579,11 +649,12 @@ class LabelTensor(torch.Tensor): def detach(self): """ Detaches the tensor from the computation graph and retains the stored - labels. + labels. For more details, see :meth:`torch.Tensor.detach`. :return: A new tensor detached from the computation graph. :rtype: LabelTensor """ + lt = super().detach() # Copy the labels to the new tensor only if present @@ -594,14 +665,15 @@ class LabelTensor(torch.Tensor): @staticmethod def summation(tensors): """ - Computes the summation of a list of tensors. + Computes the summation of a list of :class:`LabelTensor` instances. + - :param tensors: A list of tensors to sum. All tensors must have the same - shape and labels. - :type tensors: list of LabelTensor + :param list(LabelTensor) tensors: A list of tensors to sum. All tensors + must have the same shape and labels. :return: A new `LabelTensor` containing the element-wise sum of the input tensors. :rtype: LabelTensor + :raises ValueError: If the input `tensors` list is empty. :raises RuntimeError: If the tensors have different shapes and/or mismatched labels. @@ -637,12 +709,14 @@ class LabelTensor(torch.Tensor): def reshape(self, *shape): """ Override the reshape method to update the labels of the tensor. + For more details, see :meth:`torch.Tensor.reshape`. - :param shape: The new shape of the tensor. - :type shape: tuple - :return: A tensor-like object with updated labels. + :param tuple shape: The new shape of the tensor. + :return: A new :class:`LabelTensor` instance with the updated shape and + labels. :rtype: LabelTensor """ + # As for now the reshape method is used only in the context of the # dataset, the labels are not tensor = super().reshape(*shape)