diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 4647ee5..a3a6160 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor): @property def tensor(self): """ - Give the tensor part of the :class:`~pina.label_tensor.LabelTensor` + Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor` object. - :return: tensor part of the :class:`~pina.label_tensor.LabelTensor`. + :return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`. :rtype: torch.Tensor """ @@ -44,10 +44,15 @@ class LabelTensor(torch.Tensor): def __init__(self, x, labels): """ - Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of - the labels and a :class:`torch.Tensor`. Internally, the initialization - method will store check the compatibility of the labels with the tensor - shape. + Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by + checking the consistency of the labels and the tensor. Specifically, the + labels must match the following conditions: + + - At each dimension, the number of labels must match the size of the \ + dimension. + - At each dimension, the labels must be unique. + + The labels can be passed in the following formats: :Example: >>> from pina import LabelTensor @@ -57,11 +62,18 @@ class LabelTensor(torch.Tensor): >>> tensor = LabelTensor( >>> torch.rand((2000, 3)), ... ["a", "b", "c"]) + + The keys of the dictionary are the dimension indices, and the values are + dictionaries containing the labels and the name of the dimension. If + the labels are passed as a list, these are assigned to the last + dimension. + :param torch.Tensor x: The tensor to be casted as a + :class:`~pina.label_tensor.LabelTensor`. + :param labels: Labels to assign to the tensor. + :type labels: str | list[str] | dict + :raises ValueError: If the labels are not consistent with the tensor. """ - # Avoid unused argument warning. x is not used in the constructor - # of the parent class. - # pylint: disable=unused-argument super().__init__() if labels is not None: self.labels = labels @@ -71,7 +83,7 @@ class LabelTensor(torch.Tensor): @property def full_labels(self): """ - Gives the full labels of the tensor, even for the dimensions that are + Returns the full labels of the tensor, even for the dimensions that are not labeled. :return: The full labels of the tensor @@ -89,7 +101,7 @@ class LabelTensor(torch.Tensor): @property def stored_labels(self): """ - Gives the labels stored inside the instance. + Returns the labels stored inside the instance. :return: The labels stored inside the instance. :rtype: dict @@ -99,7 +111,7 @@ class LabelTensor(torch.Tensor): @property def labels(self): """ - Give the labels of the last dimension of the instance. + Returns the labels of the last dimension of the instance. :return: labels of last dimension :rtype: list @@ -111,8 +123,9 @@ class LabelTensor(torch.Tensor): @labels.setter def labels(self, labels): """ - Set the parameter ``_labels`` by checking the type of the input labels - and handling it accordingly. The following types are accepted: + Set labels stored insider the instance by checking the type of the + input labels and handling it accordingly. The following types are + accepted: - **list**: The list of labels is assigned to the last dimension. - **dict**: The dictionary of labels is assigned to the tensor. @@ -134,7 +147,7 @@ class LabelTensor(torch.Tensor): else: raise ValueError("labels must be list, dict or string.") - def _init_labels_from_dict(self, labels: dict): + def _init_labels_from_dict(self, labels): """ Store the internal label representation according to the values passed as input. @@ -146,7 +159,7 @@ class LabelTensor(torch.Tensor): tensor_shape = self.shape - def validate_dof(dof_list, dim_size: int): + def validate_dof(dof_list, dim_size): """Validate the 'dof' list for uniqueness and size.""" if len(dof_list) != len(set(dof_list)): raise ValueError("dof must be unique") @@ -187,7 +200,7 @@ class LabelTensor(torch.Tensor): def _init_labels_from_list(self, labels): """ - Given a ``list`` of dof, this method update the internal label + Given a list of dof, this method update the internal label representation by assigning the dof to the last dimension. :param labels: The label(s) to update. @@ -203,17 +216,25 @@ class LabelTensor(torch.Tensor): def extract(self, labels_to_extract): """ Extract the subset of the original tensor by returning all the positions - corresponding to the passed ``label_to_extract``. + corresponding to the passed ``label_to_extract``. If ``label_to_extract`` + is a dictionary, the keys are the dimension names and the values are the + labels to extract. If a single label or a list of labels is passed, the + last dimension is considered. - :param labels_to_extract: The label(s) to extract. If a single label or - a list of labels is passed, the last dimension is considered. - If a dictionary is passed, the keys are the dimension names and the - values are the labels to extract. + :Example: + >>> from pina import LabelTensor + >>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}} + >>> tensor = LabelTensor(torch.rand((2000, 3)), labels) + >>> tensor.extract("a") + >>> tensor.extract(["a", "b"]) + >>> tensor.extract({"space": ["a", "b"]}) + + :param labels_to_extract: The label(s) to extract. :type labels_to_extract: str | list[str] | tuple[str] | dict :return: The extracted tensor with the updated labels. :rtype: LabelTensor - :raises TypeError: Labels are not ``str``, ``list of str`` or ``dict`` + :raises TypeError: Labels are not ``str``, ``list[str]`` or ``dict`` properly setted. :raises ValueError: Label to extract is not in the labels ``list``. """ @@ -298,13 +319,13 @@ class LabelTensor(torch.Tensor): :param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor` instances to concatenate - :param int dim: dimensions on which you want to perform the operation + :param int dim: Dimensions on which you want to perform the operation (default is 0) - :return: A new :class:`LabelTensor' instance obtained by concatenating - the input instances, with the updated labels. + :return: A new :class:`LabelTensor` instance obtained by concatenating + the input instances. :rtype: LabelTensor - :raises ValueError: either number dof or dimensions names differ + :raises ValueError: either number dof or dimensions names differ. """ if not tensors: @@ -353,7 +374,7 @@ class LabelTensor(torch.Tensor): :param list[LabelTensor] tensors: A list of tensors to stack. All tensors must have the same shape. :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained - by stacking the input tensors, with the updated labels. + by stacking the input tensors. :rtype: LabelTensor """ @@ -389,7 +410,7 @@ class LabelTensor(torch.Tensor): Give the ``dtype`` of the tensor. For more details, see :meth:`torch.dtype`. - :return: dtype of the tensor + :return: The data type of the tensor. :rtype: torch.dtype """ @@ -427,19 +448,19 @@ class LabelTensor(torch.Tensor): def append(self, tensor, mode="std"): """ Appends a given tensor to the current tensor along the last dimension. - This method supports two types of appending operations: - 1. **Standard append** ("std"): Concatenates the input tensor with the + + 1. **Standard append** ("std"): Concatenates the input tensor with the \ current tensor along the last dimension. - 2. **Cross append** ("cross"): Creates a cross-product of the current - tensor and the input tensor by repeating them in a cross-product - fashion, then concatenates the result along the last dimension. + 2. **Cross append** ("cross"): Creates a cross-product of the current \ + tensor and the input tensor. :param tensor: The tensor to append to the current tensor. :type tensor: LabelTensor - :param mode: The append mode to use. Defaults to "std". + :param mode: The append mode to use. Defaults to ``st``. :type mode: str, optional - :return: A new `LabelTensor` obtained by appending the input tensor. + :return: A new :class:`LabelTensor` instance obtained by appending the + input tensor. :rtype: LabelTensor :raises ValueError: If the mode is not "std" or "cross". @@ -468,7 +489,7 @@ class LabelTensor(torch.Tensor): raise ValueError('mode must be either "std" or "cross"') @staticmethod - def vstack(label_tensors): + def vstack(tensors): """ Stack tensors vertically. For more details, see :meth:`torch.vstack`. @@ -480,7 +501,7 @@ class LabelTensor(torch.Tensor): :rtype: LabelTensor """ - return LabelTensor.cat(label_tensors, dim=0) + return LabelTensor.cat(tensors, dim=0) # This method is used to update labels def _update_single_label( @@ -592,11 +613,11 @@ class LabelTensor(torch.Tensor): def sort_labels(self, dim=None): """ - Sort the labels along the specified dimension and apply the same sorting - to the :class:`torch.Tensor` part of the instance. + Sort the labels along the specified dimension and apply. It applies the + same sorting to the tensor part of the instance. :param int dim: The dimension along which to sort the labels. - If ``None``, the last dimension (``ndim - 1``) is used. + If ``None``, the last dimension is used. :return: A new tensor with sorted labels along the specified dimension. :rtype: LabelTensor """