Update doc LT

This commit is contained in:
FilippoOlivo
2025-03-12 09:28:54 +01:00
parent d3629b2b54
commit e7998b629b

View File

@@ -19,7 +19,7 @@ class LabelTensor(torch.Tensor):
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`LabelTensor`. :class:`LabelTensor`.
:param labels: Labels to assign to the tensor. :param labels: Labels to assign to the tensor.
:type labels: str | list(str) | dict :type labels: str | list[str] | dict
:return: The instance of the :class:`LabelTensor` class. :return: The instance of the :class:`LabelTensor` class.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -115,7 +115,7 @@ class LabelTensor(torch.Tensor):
- **str**: The string is assigned to the last dimension. - **str**: The string is assigned to the last dimension.
:param labels: Labels to assign to the class variable _labels. :param labels: Labels to assign to the class variable _labels.
:type labels: str | list(str) | dict :type labels: str | list[str] | dict
""" """
if not hasattr(self, "_labels"): if not hasattr(self, "_labels"):
@@ -205,11 +205,11 @@ class LabelTensor(torch.Tensor):
a list of labels is passed, the last dimension is considered. a list of labels is passed, the last dimension is considered.
If a dictionary is passed, the keys are the dimension names and the If a dictionary is passed, the keys are the dimension names and the
values are the labels to extract. values are the labels to extract.
:type labels_to_extract: str | list(str) | tuple(str) | dict :type labels_to_extract: str | list[str] | tuple[str] | dict
:return: The extracted tensor with the updated labels. :return: The extracted tensor with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
:raises TypeError: Labels are not ``str``, ``list(str)`` or ``dict`` :raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
properly setted. properly setted.
:raises ValueError: Label to extract is not in the labels ``list``. :raises ValueError: Label to extract is not in the labels ``list``.
""" """
@@ -290,7 +290,7 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`. details, see :meth:`torch.cat`.
:param list(LabelTensor) tensors: :class:`LabelTensor` instances to :param list[LabelTensor] tensors: :class:`LabelTensor` instances to
concatenate concatenate
:param int dim: dimensions on which you want to perform the operation :param int dim: dimensions on which you want to perform the operation
(default is 0) (default is 0)
@@ -344,7 +344,7 @@ class LabelTensor(torch.Tensor):
Stacks a list of tensors along a new dimension. For more details, see Stacks a list of tensors along a new dimension. For more details, see
:meth:`torch.stack`. :meth:`torch.stack`.
:param list(LabelTensor) tensors: A list of tensors to stack. :param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape. All tensors must have the same shape.
:return: A new :class:`LabelTensor` instance obtained by stacking the :return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors, with the updated labels. input tensors, with the updated labels.
@@ -466,7 +466,7 @@ class LabelTensor(torch.Tensor):
""" """
Stack tensors vertically. For more details, see :meth:`torch.vstack`. Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: The :class:`LabelTensor` :param list of LabelTensor label_tensors: The :class:`LabelTensor`
instances to stack. They need to have equal labels. instances to stack. They need to have equal labels.
:return: A new :class:`LabelTensor` instance obtained by stacking the :return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors vertically. input tensors vertically.
@@ -485,7 +485,7 @@ class LabelTensor(torch.Tensor):
:param dict old_labels: Labels from which retrieve data. :param dict old_labels: Labels from which retrieve data.
:param dict to_update_labels: Labels to update. :param dict to_update_labels: Labels to update.
:param index: Index of dof to retain. :param index: Index of dof to retain.
:type index: int | slice | list | torch.Tensor] :type index: int | slice | list[int] | tuple[int] | torch.Tensor
:param int dim: The dimension to update. :param int dim: The dimension to update.
:raises: ValueError: If the index type is not supported. :raises: ValueError: If the index type is not supported.
@@ -529,7 +529,7 @@ class LabelTensor(torch.Tensor):
labels based on the index. labels based on the index.
:param index: The index used to access the item :param index: The index used to access the item
:type index: int | str | tuple | list | torch.Tensor :type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`LabelTensor` instance obtained __getitem__ :return: A new :class:`LabelTensor` instance obtained __getitem__
operation on :class:`torch.Tensor` part of the instance, with the operation on :class:`torch.Tensor` part of the instance, with the
updated labels. updated labels.
@@ -630,7 +630,7 @@ class LabelTensor(torch.Tensor):
accordingly. For more details, see :meth:`torch.Tensor.permute`. accordingly. For more details, see :meth:`torch.Tensor.permute`.
:param dims: The dimensions to permute the tensor to. :param dims: The dimensions to permute the tensor to.
:type dims: tuple, list :type dims: tuple[int] | list[int]
:return: A new object with permuted dimensions and reordered labels. :return: A new object with permuted dimensions and reordered labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -668,8 +668,8 @@ class LabelTensor(torch.Tensor):
Computes the summation of a list of :class:`LabelTensor` instances. Computes the summation of a list of :class:`LabelTensor` instances.
:param list(LabelTensor) tensors: A list of tensors to sum. All tensors :param list[LabelTensor] tensors: A list of tensors to sum. All
must have the same shape and labels. tensors must have the same shape and labels.
:return: A new `LabelTensor` containing the element-wise sum of the :return: A new `LabelTensor` containing the element-wise sum of the
input tensors. input tensors.
:rtype: LabelTensor :rtype: LabelTensor
@@ -711,7 +711,7 @@ class LabelTensor(torch.Tensor):
Override the reshape method to update the labels of the tensor. Override the reshape method to update the labels of the tensor.
For more details, see :meth:`torch.Tensor.reshape`. For more details, see :meth:`torch.Tensor.reshape`.
:param tuple shape: The new shape of the tensor. :param tuple of int shape: The new shape of the tensor.
:return: A new :class:`LabelTensor` instance with the updated shape and :return: A new :class:`LabelTensor` instance with the updated shape and
labels. labels.
:rtype: LabelTensor :rtype: LabelTensor