Black formatting on LabelTensor

This commit is contained in:
FilippoOlivo
2025-03-11 14:48:17 +01:00
parent c99157d98a
commit 92bb04fafe

View File

@@ -294,7 +294,7 @@ class LabelTensor(torch.Tensor):
concatenate concatenate
:param int dim: dimensions on which you want to perform the operation :param int dim: dimensions on which you want to perform the operation
(default is 0) (default is 0)
:return: A new :class:`LabelTensor' instance obtained by concatenating :return: A new :class:`LabelTensor' instance obtained by concatenating
the input instances, with the updated labels. the input instances, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
@@ -346,7 +346,7 @@ class LabelTensor(torch.Tensor):
:param list(LabelTensor) tensors: A list of tensors to stack. :param list(LabelTensor) tensors: A list of tensors to stack.
All tensors must have the same shape. All tensors must have the same shape.
:return: A new :class:`LabelTensor` instance obtained by stacking the :return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors, with the updated labels. input tensors, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -368,7 +368,7 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should :param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients; track gradients.If `True`, the tensor will track gradients;
if `False`, it will not. if `False`, it will not.
:return: The :class:`LabelTensor` itself with the updated :return: The :class:`LabelTensor` itself with the updated
`requires_grad` state and retained labels. `requires_grad` state and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -394,7 +394,7 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`. :meth:`torch.Tensor.to`.
:return: A new :class:`LabelTensor` instance with the updated dtype :return: A new :class:`LabelTensor` instance with the updated dtype
and/or device and retained labels. and/or device and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -408,7 +408,7 @@ class LabelTensor(torch.Tensor):
Clone the :class:`LabelTensor`. For more details, see Clone the :class:`LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`. :meth:`torch.Tensor.clone`.
:return: A new :class:`LabelTensor` instance with the same data and :return: A new :class:`LabelTensor` instance with the same data and
labels but allocated in a different memory location. labels but allocated in a different memory location.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -466,7 +466,7 @@ class LabelTensor(torch.Tensor):
""" """
Stack tensors vertically. For more details, see :meth:`torch.vstack`. Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: The :class:`LabelTensor` :param list(LabelTensor) label_tensors: The :class:`LabelTensor`
instances to stack. They need to have equal labels. instances to stack. They need to have equal labels.
:return: A new :class:`LabelTensor` instance obtained by stacking the :return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors vertically. input tensors vertically.
@@ -523,7 +523,7 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index): def __getitem__(self, index):
""" " """ "
Override the __getitem__ method to handle the labels of the Override the __getitem__ method to handle the labels of the
:class:`LabelTensor` instance. It first performs __getitem__ operation :class:`LabelTensor` instance. It first performs __getitem__ operation
on the :class:`torch.Tensor` part of the instance, then updates the on the :class:`torch.Tensor` part of the instance, then updates the
labels based on the index. labels based on the index.
@@ -666,7 +666,7 @@ class LabelTensor(torch.Tensor):
def summation(tensors): def summation(tensors):
""" """
Computes the summation of a list of :class:`LabelTensor` instances. Computes the summation of a list of :class:`LabelTensor` instances.
:param list(LabelTensor) tensors: A list of tensors to sum. All tensors :param list(LabelTensor) tensors: A list of tensors to sum. All tensors
must have the same shape and labels. must have the same shape and labels.