Other fixes

This commit is contained in:
FilippoOlivo
2025-03-12 12:29:18 +01:00
committed by Nicola Demo
parent ae796ce34c
commit f587c3bf65
9 changed files with 85 additions and 82 deletions

View File

@@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor):
@staticmethod
def __new__(cls, x, labels, *args, **kwargs):
"""
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`pina.label_tensor.LabelTensor`.
:class:`~pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
:return: The instance of the :class:`~pina.label_tensor.LabelTensor`
class.
:rtype: LabelTensor
"""
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
@property
def tensor(self):
"""
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
object.
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor
"""
@@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels):
"""
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
the labels and a :class:`torch.Tensor`. Internally, the initialization
method will store check the compatibility of the labels with the tensor
shape.
@@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor):
def __str__(self):
"""
The string representation of the :class:`pina.label_tensor.LabelTensor`.
The string representation of the :class:`~pina.label_tensor.LabelTensor`.
:return: String representation of the
:class:`pina.label_tensor.LabelTensor` instance.
:class:`~pina.label_tensor.LabelTensor` instance.
:rtype: str
"""
@@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`.
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
:param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor`
instances to concatenate
:param int dim: dimensions on which you want to perform the operation
(default is 0)
@@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors, with the updated labels.
:rtype: LabelTensor
"""
@@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients;
if `False`, it will not.
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
:return: The :class:`~pina.label_tensor.LabelTensor` itself with the
updated `requires_grad` state and retained labels.
:rtype: LabelTensor
"""
@@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated dtype and/or device and retained labels.
:rtype: LabelTensor
"""
@@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor):
def clone(self, *args, **kwargs):
"""
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
same data and labels but allocated in a different memory location.
:rtype: LabelTensor
"""
@@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor):
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list of LabelTensor label_tensors: The
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
:class:`~pina.label_tensor.LabelTensor` instances to stack. They need
to have equal labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors vertically.
:rtype: LabelTensor
"""
@@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index):
""" "
Override the __getitem__ method to handle the labels of the
:class:`pina.label_tensor.LabelTensor` instance. It first performs
:class:`~pina.label_tensor.LabelTensor` instance. It first performs
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
then updates the labels based on the index.
:param index: The index used to access the item
:type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
`__getitem__` operation on :class:`torch.Tensor` part of the
instance, with the updated labels.
:rtype: LabelTensor
@@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor):
def summation(tensors):
"""
Computes the summation of a list of
:class:`pina.label_tensor.LabelTensor` instances.
:class:`~pina.label_tensor.LabelTensor` instances.
:param list[LabelTensor] tensors: A list of tensors to sum. All
@@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor):
For more details, see :meth:`torch.Tensor.reshape`.
:param tuple of int shape: The new shape of the tensor.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated shape and labels.
:rtype: LabelTensor
"""