Update doc LT

This commit is contained in:
FilippoOlivo
2025-03-12 09:28:54 +01:00
parent d3629b2b54
commit e7998b629b

View File

@@ -19,7 +19,7 @@ class LabelTensor(torch.Tensor):
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list(str) | dict
:type labels: str | list[str] | dict
:return: The instance of the :class:`LabelTensor` class.
:rtype: LabelTensor
"""
@@ -115,7 +115,7 @@ class LabelTensor(torch.Tensor):
- **str**: The string is assigned to the last dimension.
:param labels: Labels to assign to the class variable _labels.
:type labels: str | list(str) | dict
:type labels: str | list[str] | dict
"""
if not hasattr(self, "_labels"):
@@ -205,11 +205,11 @@ class LabelTensor(torch.Tensor):
a list of labels is passed, the last dimension is considered.
If a dictionary is passed, the keys are the dimension names and the
values are the labels to extract.
:type labels_to_extract: str | list(str) | tuple(str) | dict
:type labels_to_extract: str | list[str] | tuple[str] | dict
:return: The extracted tensor with the updated labels.
:rtype: LabelTensor
:raises TypeError: Labels are not ``str``, ``list(str)`` or ``dict``
:raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
properly setted.
:raises ValueError: Label to extract is not in the labels ``list``.
"""
@@ -290,7 +290,7 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`.
:param list(LabelTensor) tensors: :class:`LabelTensor` instances to
:param list[LabelTensor] tensors: :class:`LabelTensor` instances to
concatenate
:param int dim: dimensions on which you want to perform the operation
(default is 0)
@@ -344,7 +344,7 @@ class LabelTensor(torch.Tensor):
Stacks a list of tensors along a new dimension. For more details, see
:meth:`torch.stack`.
:param list(LabelTensor) tensors: A list of tensors to stack.
:param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape.
:return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors, with the updated labels.
@@ -466,7 +466,7 @@ class LabelTensor(torch.Tensor):
"""
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: The :class:`LabelTensor`
:param list of LabelTensor label_tensors: The :class:`LabelTensor`
instances to stack. They need to have equal labels.
:return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors vertically.
@@ -485,7 +485,7 @@ class LabelTensor(torch.Tensor):
:param dict old_labels: Labels from which retrieve data.
:param dict to_update_labels: Labels to update.
:param index: Index of dof to retain.
:type index: int | slice | list | torch.Tensor]
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
:param int dim: The dimension to update.
:raises: ValueError: If the index type is not supported.
@@ -529,7 +529,7 @@ class LabelTensor(torch.Tensor):
labels based on the index.
:param index: The index used to access the item
:type index: int | str | tuple | list | torch.Tensor
:type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`LabelTensor` instance obtained __getitem__
operation on :class:`torch.Tensor` part of the instance, with the
updated labels.
@@ -630,7 +630,7 @@ class LabelTensor(torch.Tensor):
accordingly. For more details, see :meth:`torch.Tensor.permute`.
:param dims: The dimensions to permute the tensor to.
:type dims: tuple, list
:type dims: tuple[int] | list[int]
:return: A new object with permuted dimensions and reordered labels.
:rtype: LabelTensor
"""
@@ -668,8 +668,8 @@ class LabelTensor(torch.Tensor):
Computes the summation of a list of :class:`LabelTensor` instances.
:param list(LabelTensor) tensors: A list of tensors to sum. All tensors
must have the same shape and labels.
:param list[LabelTensor] tensors: A list of tensors to sum. All
tensors must have the same shape and labels.
:return: A new `LabelTensor` containing the element-wise sum of the
input tensors.
:rtype: LabelTensor
@@ -711,7 +711,7 @@ class LabelTensor(torch.Tensor):
Override the reshape method to update the labels of the tensor.
For more details, see :meth:`torch.Tensor.reshape`.
:param tuple shape: The new shape of the tensor.
:param tuple of int shape: The new shape of the tensor.
:return: A new :class:`LabelTensor` instance with the updated shape and
labels.
:rtype: LabelTensor