Other fixes

This commit is contained in:
FilippoOlivo
2025-03-12 12:29:18 +01:00
parent d857b47002
commit 033d36c5a8
9 changed files with 85 additions and 82 deletions

View File

@@ -18,7 +18,8 @@ class Collector:
and the problem and initializing the data collections (dictionary where and the problem and initializing the data collections (dictionary where
data will be stored). data will be stored).
:param AbstractProblem problem: The problem to collect data from. :param pina.problem.abstract_problem.AbstractProblem problem: The
problem to collect data from.
""" """
# creating a hook between collector and problem # creating a hook between collector and problem
self.problem = problem self.problem = problem
@@ -71,16 +72,17 @@ class Collector:
Problem connected to the collector. Problem connected to the collector.
:return: The problem from which the data is collected. :return: The problem from which the data is collected.
:rtype: AbstractProblem :rtype: pina.problem.abstract_problem.AbstractProblem
""" """
return self._problem return self._problem
@problem.setter @problem.setter
def problem(self, value): def problem(self, value):
""" """
Ser the problem connected to the collector. Set the problem connected to the collector.
:param AbstractProblem value: The problem to connect to the collector. :param pina.problem.abstract_problem.AbstractProblem value: The problem
to connect to the collector.
""" """
self._problem = value self._problem = value

View File

@@ -89,15 +89,15 @@ class Condition:
Create a new condition object based on the keyword arguments passed. Create a new condition object based on the keyword arguments passed.
- `input` and `target`: - `input` and `target`:
:class:`pina.condition.input_target_condition.InputTargetCondition` :class:`~pina.condition.input_target_condition.InputTargetCondition`
- `domain` and `equation`: - `domain` and `equation`:
:class:`pina.condition.domain_equation_condition. :class:`~pina.condition.domain_equation_condition.
DomainEquationCondition` DomainEquationCondition`
- `input` and `equation`: :class:`pina.condition. - `input` and `equation`: :class:`~pina.condition.
input_equation_condition.InputEquationCondition` input_equation_condition.InputEquationCondition`
- `input`: :class:`pina.condition.data_condition.DataCondition` - `input`: :class:`~pina.condition.data_condition.DataCondition`
- `input` and `conditional_variables`: - `input` and `conditional_variables`:
:class:`pina.condition.data_condition.DataCondition` :class:`~pina.condition.data_condition.DataCondition`
:raises ValueError: No valid condition has been found. :raises ValueError: No valid condition has been found.
:return: A new condition instance belonging to the proper class. :return: A new condition instance belonging to the proper class.

View File

@@ -35,7 +35,7 @@ class DataCondition(ConditionInterface):
pina.condition.data_condition.GraphDataCondition pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`, :raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
or :class:`~torch_geometric.data.Data`. or :class:`~torch_geometric.data.Data`.
@@ -68,7 +68,7 @@ class DataCondition(ConditionInterface):
:type conditional_variables: torch.Tensor or LabelTensor :type conditional_variables: torch.Tensor or LabelTensor
.. note:: .. note::
If either `input` is composed by a list of :class:`pina.graph.Graph` If either `input` is composed by a list of :class:`~pina.graph.Graph`
or :class:`~torch_geometric.data.Data` objects, all elements must or :class:`~torch_geometric.data.Data` objects, all elements must
have the same structure (keys and data types) have the same structure (keys and data types)
""" """
@@ -80,12 +80,12 @@ class DataCondition(ConditionInterface):
class TensorDataCondition(DataCondition): class TensorDataCondition(DataCondition):
""" """
DataCondition for :class:`torch.Tensor` or DataCondition for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input data :class:`~pina.label_tensor.LabelTensor` input data
""" """
class GraphDataCondition(DataCondition): class GraphDataCondition(DataCondition):
""" """
DataCondition for :class:`pina.graph.Graph` or DataCondition for :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input data :class:`~torch_geometric.data.Data` input data
""" """

View File

@@ -35,7 +35,7 @@ class InputEquationCondition(ConditionInterface):
pina.condition.input_equation_condition.InputGraphEquationCondition pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type :raises ValueError: If input is not of type
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`. :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`.
""" """
# If the class is already a subclass, return the instance # If the class is already a subclass, return the instance
@@ -62,15 +62,16 @@ class InputEquationCondition(ConditionInterface):
Initialize the InputEquationCondition by storing the input and equation. Initialize the InputEquationCondition by storing the input and equation.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph] :type input: pina.label_tensor.LabelTensor | pina.graph.Graph |
list[pina.graph.Graph] | tuple[pina.graph.Graph]
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation function. equation function.
.. note:: .. note::
If `input` is composed by a list of :class:`pina.graph.Graph` If `input` is composed by a list of :class:`~pina.graph.Graph`
objects, all elements must have the same structure (keys and data objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a types). Moreover, at least one attribute must be a
:class:`pina.label_tensor.LabelTensor`. :class:`~pina.label_tensor.LabelTensor`.
""" """
super().__init__() super().__init__()
@@ -90,21 +91,21 @@ class InputEquationCondition(ConditionInterface):
class InputTensorEquationCondition(InputEquationCondition): class InputTensorEquationCondition(InputEquationCondition):
""" """
InputEquationCondition subclass for :class:`pina.label_tensor.LabelTensor` InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor`
input data. input data.
""" """
class InputGraphEquationCondition(InputEquationCondition): class InputGraphEquationCondition(InputEquationCondition):
""" """
InputEquationCondition subclass for :class:`pina.graph.Graph` input data. InputEquationCondition subclass for :class:`~pina.graph.Graph` input data.
""" """
@staticmethod @staticmethod
def _check_label_tensor(input): def _check_label_tensor(input):
""" """
Check if at least one :class:`pina.label_tensor.LabelTensor` is present Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
in the :class:`pina.graph.Graph` object. in the :class:`~pina.graph.Graph` object.
:param input: Input data. :param input: Input data.
:type input: torch.Tensor | Graph | Data :type input: torch.Tensor | Graph | Data

View File

@@ -44,8 +44,8 @@ class InputTargetCondition(ConditionInterface):
pina.condition.input_target_condition.GraphInputGraphTargetCondition pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type :raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
""" """
if cls != InputTargetCondition: if cls != InputTargetCondition:
return super().__new__(cls) return super().__new__(cls)
@@ -97,7 +97,7 @@ class InputTargetCondition(ConditionInterface):
.. note:: .. note::
If either `input` or `target` are composed by a list of If either `input` or `target` are composed by a list of
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements must have the same structure (keys and data objects, all elements must have the same structure (keys and data
types) types)
""" """
@@ -122,29 +122,29 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition): class TensorInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input and target data. :class:`~pina.label_tensor.LabelTensor` input and target data.
""" """
class TensorInputGraphTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input and :class:`~pina.label_tensor.LabelTensor` input and
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` target :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
data. data.
""" """
class GraphInputTensorTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`pina.graph.Graph` o InputTargetCondition subclass for :class:`~pina.graph.Graph` o
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` target data. :class:`~pina.label_tensor.LabelTensor` target data.
""" """
class GraphInputGraphTargetCondition(InputTargetCondition): class GraphInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`pina.graph.Graph`/ InputTargetCondition subclass for :class:`~pina.graph.Graph`/
:class:`~torch_geometric.data.Data` input and target data. :class:`~torch_geometric.data.Data` input and target data.
""" """

View File

@@ -157,7 +157,7 @@ class Collator:
def _collate_tensor_dataset(data_list): def _collate_tensor_dataset(data_list):
""" """
Function used to collate the data when the dataset is a Function used to collate the data when the dataset is a
:class:`pina.data.dataset.PinaTensorDataset`. :class:`~pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated. :param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor] :type data_list: list[torch.Tensor] | list[LabelTensor]
@@ -165,7 +165,7 @@ class Collator:
:rtype: dict :rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`pina.label_tensor.LabelTensor`. :class:`~pina.label_tensor.LabelTensor`.
""" """
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):
@@ -177,7 +177,7 @@ class Collator:
def _collate_graph_dataset(self, data_list): def _collate_graph_dataset(self, data_list):
""" """
Function used to collate the data when the dataset is a Function used to collate the data when the dataset is a
:class:`pina.data.dataset.PinaGraphDataset`. :class:`~pina.data.dataset.PinaGraphDataset`.
:param data_list: Elememts to be collated. :param data_list: Elememts to be collated.
:type data_list: list[Data] | list[Graph] :type data_list: list[Data] | list[Graph]
@@ -185,7 +185,7 @@ class Collator:
:rtype: dict :rtype: dict
:raises RuntimeError: If the data is not a :raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`. :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
""" """
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):

View File

@@ -15,26 +15,26 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset Depending on the type inside the conditions, it creates a different dataset
object: object:
- :class:`pina.data.dataset.PinaTensorDataset` for handling - :class:`~pina.data.dataset.PinaTensorDataset` for handling
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data. :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
- :class:`pina.data.dataset.PinaGraphDataset` for handling - :class:`~pina.data.dataset.PinaGraphDataset` for handling
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
""" """
def __new__(cls, conditions_dict, **kwargs): def __new__(cls, conditions_dict, **kwargs):
""" """
Instantiate the appropriate subclass of Instantiate the appropriate subclass of
:class:`pina.data.dataset.PinaDataset`. :class:`~pina.data.dataset.PinaDataset`.
If a graph is present in the conditions, returns a If a graph is present in the conditions, returns a
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`pina.data.dataset.PinaTensorDataset`. :class:`~pina.data.dataset.PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing all the conditions :param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance. to be included in the dataset instance.
:return: A subclass of :class:`pina.data.dataset.PinaDataset`. :return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: :class:`pina.data.dataset.PinaTensorDataset` | :rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
:class:`pina.data.dataset.PinaGraphDataset` :class:`~pina.data.dataset.PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided. :raises ValueError: If an empty dictionary is provided.
""" """
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC): class PinaDataset(Dataset, ABC):
""" """
Abstract class for the PINA dataset. It defines the common interface for Abstract class for the PINA dataset. It defines the common interface for
the :class:`pina.data.dataset.PinaTensorDataset` and the :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`pina.data.dataset.PinaGraphDataset` classes. :class:`~pina.data.dataset.PinaGraphDataset` classes.
""" """
def __init__( def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching self, conditions_dict, max_conditions_lengths, automatic_batching
): ):
""" """
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
the providedconditions dictionary, the maximum number of conditions to the providedconditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag. consider, and the automatic batching flag.
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
points to include in a single batch for each condition. points to include in a single batch for each condition.
:param bool automatic_batching: Indicates whether PyTorch automatic :param bool automatic_batching: Indicates whether PyTorch automatic
batching is enabled in batching is enabled in
:class:`pina.data.data_module.PinaDataModule`. :class:`~pina.data.data_module.PinaDataModule`.
""" """
# Store the conditions dictionary # Store the conditions dictionary
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
class PinaTensorDataset(PinaDataset): class PinaTensorDataset(PinaDataset):
""" """
Dataset class for the PINA dataset with :class:`torch.Tensor` and Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`pina.label_tensor.LabelTensor` data. :class:`~pina.label_tensor.LabelTensor` data.
""" """
# Override _retrive_data method for torch.Tensor data # Override _retrive_data method for torch.Tensor data
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
:param dict data: Dictionary containing the data :param dict data: Dictionary containing the data
(only :class:`torch.Tensor` or (only :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor`). :class:`~pina.label_tensor.LabelTensor`).
:param list[int] idx_list: indices to retrieve. :param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices. :return: Dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset): class PinaGraphDataset(PinaDataset):
""" """
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
and :class:`pina.graph.Graph` data. and :class:`~pina.graph.Graph` data.
""" """
def _create_graph_batch(self, data): def _create_graph_batch(self, data):

View File

@@ -20,13 +20,13 @@ class Graph(Data):
**kwargs, **kwargs,
): ):
""" """
Create a new instance of the :class:`pina.graph.Graph` class by checking Create a new instance of the :class:`~pina.graph.Graph` class by checking
the consistency of the input data and storing the attributes. the consistency of the input data and storing the attributes.
:param kwargs: Parameters used to initialize the :param kwargs: Parameters used to initialize the
:class:`pina.graph.Graph` object. :class:`~pina.graph.Graph` object.
:type kwargs: dict :type kwargs: dict
:return: A new instance of the :class:`pina.graph.Graph` class. :return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph :rtype: Graph
""" """
@@ -230,7 +230,7 @@ class GraphBuilder:
): ):
""" """
Compute the edge attributes and create a new instance of the Compute the edge attributes and create a new instance of the
:class:`pina.graph.Graph` class. :class:`~pina.graph.Graph` class.
:param pos: A tensor of shape `(N, D)` representing the positions of `N` :param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space. points in `D`-dimensional space.
@@ -248,8 +248,8 @@ class GraphBuilder:
If provided, overrides `edge_attr`. If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional :type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
:class:`pina.graph.Graph` class constructor. :class:`~pina.graph.Graph` class constructor.
:return: A :class:`pina.graph.Graph` instance constructed using the :return: A :class:`~pina.graph.Graph` instance constructed using the
provided information. provided information.
:rtype: Graph :rtype: Graph
""" """
@@ -289,7 +289,7 @@ class RadiusGraph(GraphBuilder):
def __new__(cls, pos, radius, **kwargs): def __new__(cls, pos, radius, **kwargs):
""" """
Extends the :class:`pina.graph.GraphBuilder` class to compute Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points edge_index based on a radius. Each point is connected to all the points
within the radius. within the radius.
@@ -299,9 +299,9 @@ class RadiusGraph(GraphBuilder):
:param radius: The radius within which points are connected. :param radius: The radius within which points are connected.
:type radius: float :type radius: float
:param kwargs: Additional keyword arguments to be passed to the :param kwargs: Additional keyword arguments to be passed to the
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph` :class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
constructors. constructors.
:return: A :class:`pina.graph.Graph` instance containing the input :return: A :class:`~pina.graph.Graph` instance containing the input
information and the computed edge_index. information and the computed edge_index.
:rtype: Graph :rtype: Graph
""" """
@@ -339,7 +339,7 @@ class KNNGraph(GraphBuilder):
def __new__(cls, pos, neighbours, **kwargs): def __new__(cls, pos, neighbours, **kwargs):
""" """
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index
based on a K-nearest neighbors algorithm. based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape (N, D) representing the positions of N
@@ -351,7 +351,7 @@ class KNNGraph(GraphBuilder):
The additional keyword arguments to be passed to GraphBuilder The additional keyword arguments to be passed to GraphBuilder
and Graph classes and Graph classes
:return: A :class:`pina.graph.Graph` instance containg the :return: A :class:`~pina.graph.Graph` instance containg the
information passed in input and the computed edge_index information passed in input and the computed edge_index
:rtype: Graph :rtype: Graph
""" """

View File

@@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def __new__(cls, x, labels, *args, **kwargs): def __new__(cls, x, labels, *args, **kwargs):
""" """
Create a new instance of the :class:`pina.label_tensor.LabelTensor` Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
class. class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`pina.label_tensor.LabelTensor`. :class:`~pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor. :param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict :type labels: str | list[str] | dict
:return: The instance of the :class:`pina.label_tensor.LabelTensor` :return: The instance of the :class:`~pina.label_tensor.LabelTensor`
class. class.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
@property @property
def tensor(self): def tensor(self):
""" """
Give the tensor part of the :class:`pina.label_tensor.LabelTensor` Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
object. object.
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`. :return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
@@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels): def __init__(self, x, labels):
""" """
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
the labels and a :class:`torch.Tensor`. Internally, the initialization the labels and a :class:`torch.Tensor`. Internally, the initialization
method will store check the compatibility of the labels with the tensor method will store check the compatibility of the labels with the tensor
shape. shape.
@@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor):
def __str__(self): def __str__(self):
""" """
The string representation of the :class:`pina.label_tensor.LabelTensor`. The string representation of the :class:`~pina.label_tensor.LabelTensor`.
:return: String representation of the :return: String representation of the
:class:`pina.label_tensor.LabelTensor` instance. :class:`~pina.label_tensor.LabelTensor` instance.
:rtype: str :rtype: str
""" """
@@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`. details, see :meth:`torch.cat`.
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor` :param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor`
instances to concatenate instances to concatenate
:param int dim: dimensions on which you want to perform the operation :param int dim: dimensions on which you want to perform the operation
(default is 0) (default is 0)
@@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack. :param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape. All tensors must have the same shape.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors, with the updated labels. by stacking the input tensors, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should :param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients; track gradients.If `True`, the tensor will track gradients;
if `False`, it will not. if `False`, it will not.
:return: The :class:`pina.label_tensor.LabelTensor` itself with the :return: The :class:`~pina.label_tensor.LabelTensor` itself with the
updated `requires_grad` state and retained labels. updated `requires_grad` state and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`. :meth:`torch.Tensor.to`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated dtype and/or device and retained labels. updated dtype and/or device and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor):
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs):
""" """
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`. :meth:`torch.Tensor.clone`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
same data and labels but allocated in a different memory location. same data and labels but allocated in a different memory location.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor):
Stack tensors vertically. For more details, see :meth:`torch.vstack`. Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list of LabelTensor label_tensors: The :param list of LabelTensor label_tensors: The
:class:`pina.label_tensor.LabelTensor` instances to stack. They need :class:`~pina.label_tensor.LabelTensor` instances to stack. They need
to have equal labels. to have equal labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors vertically. by stacking the input tensors vertically.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index): def __getitem__(self, index):
""" " """ "
Override the __getitem__ method to handle the labels of the Override the __getitem__ method to handle the labels of the
:class:`pina.label_tensor.LabelTensor` instance. It first performs :class:`~pina.label_tensor.LabelTensor` instance. It first performs
__getitem__ operation on the :class:`torch.Tensor` part of the instance, __getitem__ operation on the :class:`torch.Tensor` part of the instance,
then updates the labels based on the index. then updates the labels based on the index.
:param index: The index used to access the item :param index: The index used to access the item
:type index: int | str | tuple of int | list ot int | torch.Tensor :type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
`__getitem__` operation on :class:`torch.Tensor` part of the `__getitem__` operation on :class:`torch.Tensor` part of the
instance, with the updated labels. instance, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
@@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor):
def summation(tensors): def summation(tensors):
""" """
Computes the summation of a list of Computes the summation of a list of
:class:`pina.label_tensor.LabelTensor` instances. :class:`~pina.label_tensor.LabelTensor` instances.
:param list[LabelTensor] tensors: A list of tensors to sum. All :param list[LabelTensor] tensors: A list of tensors to sum. All
@@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor):
For more details, see :meth:`torch.Tensor.reshape`. For more details, see :meth:`torch.Tensor.reshape`.
:param tuple of int shape: The new shape of the tensor. :param tuple of int shape: The new shape of the tensor.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated shape and labels. updated shape and labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """