Other fixes
This commit is contained in:
committed by
Nicola Demo
parent
ae796ce34c
commit
f587c3bf65
@@ -18,7 +18,8 @@ class Collector:
|
||||
and the problem and initializing the data collections (dictionary where
|
||||
data will be stored).
|
||||
|
||||
:param AbstractProblem problem: The problem to collect data from.
|
||||
:param pina.problem.abstract_problem.AbstractProblem problem: The
|
||||
problem to collect data from.
|
||||
"""
|
||||
# creating a hook between collector and problem
|
||||
self.problem = problem
|
||||
@@ -71,16 +72,17 @@ class Collector:
|
||||
Problem connected to the collector.
|
||||
|
||||
:return: The problem from which the data is collected.
|
||||
:rtype: AbstractProblem
|
||||
:rtype: pina.problem.abstract_problem.AbstractProblem
|
||||
"""
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
"""
|
||||
Ser the problem connected to the collector.
|
||||
Set the problem connected to the collector.
|
||||
|
||||
:param AbstractProblem value: The problem to connect to the collector.
|
||||
:param pina.problem.abstract_problem.AbstractProblem value: The problem
|
||||
to connect to the collector.
|
||||
"""
|
||||
|
||||
self._problem = value
|
||||
|
||||
@@ -89,15 +89,15 @@ class Condition:
|
||||
Create a new condition object based on the keyword arguments passed.
|
||||
|
||||
- `input` and `target`:
|
||||
:class:`pina.condition.input_target_condition.InputTargetCondition`
|
||||
:class:`~pina.condition.input_target_condition.InputTargetCondition`
|
||||
- `domain` and `equation`:
|
||||
:class:`pina.condition.domain_equation_condition.
|
||||
:class:`~pina.condition.domain_equation_condition.
|
||||
DomainEquationCondition`
|
||||
- `input` and `equation`: :class:`pina.condition.
|
||||
- `input` and `equation`: :class:`~pina.condition.
|
||||
input_equation_condition.InputEquationCondition`
|
||||
- `input`: :class:`pina.condition.data_condition.DataCondition`
|
||||
- `input`: :class:`~pina.condition.data_condition.DataCondition`
|
||||
- `input` and `conditional_variables`:
|
||||
:class:`pina.condition.data_condition.DataCondition`
|
||||
:class:`~pina.condition.data_condition.DataCondition`
|
||||
|
||||
:raises ValueError: No valid condition has been found.
|
||||
:return: A new condition instance belonging to the proper class.
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataCondition(ConditionInterface):
|
||||
pina.condition.data_condition.GraphDataCondition
|
||||
|
||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,
|
||||
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
|
||||
or :class:`~torch_geometric.data.Data`.
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class DataCondition(ConditionInterface):
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
|
||||
.. note::
|
||||
If either `input` is composed by a list of :class:`pina.graph.Graph`
|
||||
If either `input` is composed by a list of :class:`~pina.graph.Graph`
|
||||
or :class:`~torch_geometric.data.Data` objects, all elements must
|
||||
have the same structure (keys and data types)
|
||||
"""
|
||||
@@ -80,12 +80,12 @@ class DataCondition(ConditionInterface):
|
||||
class TensorDataCondition(DataCondition):
|
||||
"""
|
||||
DataCondition for :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor` input data
|
||||
:class:`~pina.label_tensor.LabelTensor` input data
|
||||
"""
|
||||
|
||||
|
||||
class GraphDataCondition(DataCondition):
|
||||
"""
|
||||
DataCondition for :class:`pina.graph.Graph` or
|
||||
DataCondition for :class:`~pina.graph.Graph` or
|
||||
:class:`~torch_geometric.data.Data` input data
|
||||
"""
|
||||
|
||||
@@ -35,7 +35,7 @@ class InputEquationCondition(ConditionInterface):
|
||||
pina.condition.input_equation_condition.InputGraphEquationCondition
|
||||
|
||||
:raises ValueError: If input is not of type
|
||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
|
||||
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`.
|
||||
"""
|
||||
|
||||
# If the class is already a subclass, return the instance
|
||||
@@ -62,15 +62,16 @@ class InputEquationCondition(ConditionInterface):
|
||||
Initialize the InputEquationCondition by storing the input and equation.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
|
||||
:type input: pina.label_tensor.LabelTensor | pina.graph.Graph |
|
||||
list[pina.graph.Graph] | tuple[pina.graph.Graph]
|
||||
:param EquationInterface equation: Equation object containing the
|
||||
equation function.
|
||||
|
||||
.. note::
|
||||
If `input` is composed by a list of :class:`pina.graph.Graph`
|
||||
If `input` is composed by a list of :class:`~pina.graph.Graph`
|
||||
objects, all elements must have the same structure (keys and data
|
||||
types). Moreover, at least one attribute must be a
|
||||
:class:`pina.label_tensor.LabelTensor`.
|
||||
:class:`~pina.label_tensor.LabelTensor`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -90,21 +91,21 @@ class InputEquationCondition(ConditionInterface):
|
||||
|
||||
class InputTensorEquationCondition(InputEquationCondition):
|
||||
"""
|
||||
InputEquationCondition subclass for :class:`pina.label_tensor.LabelTensor`
|
||||
InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor`
|
||||
input data.
|
||||
"""
|
||||
|
||||
|
||||
class InputGraphEquationCondition(InputEquationCondition):
|
||||
"""
|
||||
InputEquationCondition subclass for :class:`pina.graph.Graph` input data.
|
||||
InputEquationCondition subclass for :class:`~pina.graph.Graph` input data.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _check_label_tensor(input):
|
||||
"""
|
||||
Check if at least one :class:`pina.label_tensor.LabelTensor` is present
|
||||
in the :class:`pina.graph.Graph` object.
|
||||
Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
|
||||
in the :class:`~pina.graph.Graph` object.
|
||||
|
||||
:param input: Input data.
|
||||
:type input: torch.Tensor | Graph | Data
|
||||
|
||||
@@ -44,8 +44,8 @@ class InputTargetCondition(ConditionInterface):
|
||||
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||
|
||||
:raises ValueError: If input and or target are not of type
|
||||
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
|
||||
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
||||
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||
"""
|
||||
if cls != InputTargetCondition:
|
||||
return super().__new__(cls)
|
||||
@@ -97,7 +97,7 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
.. note::
|
||||
If either `input` or `target` are composed by a list of
|
||||
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
||||
objects, all elements must have the same structure (keys and data
|
||||
types)
|
||||
"""
|
||||
@@ -122,29 +122,29 @@ class InputTargetCondition(ConditionInterface):
|
||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor` input and target data.
|
||||
:class:`~pina.label_tensor.LabelTensor` input and target data.
|
||||
"""
|
||||
|
||||
|
||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor` input and
|
||||
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
||||
:class:`~pina.label_tensor.LabelTensor` input and
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
||||
data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`pina.graph.Graph` o
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
||||
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor` target data.
|
||||
:class:`~pina.label_tensor.LabelTensor` target data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`pina.graph.Graph`/
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
||||
:class:`~torch_geometric.data.Data` input and target data.
|
||||
"""
|
||||
|
||||
@@ -157,7 +157,7 @@ class Collator:
|
||||
def _collate_tensor_dataset(data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param data_list: Elements to be collated.
|
||||
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
||||
@@ -165,7 +165,7 @@ class Collator:
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||
:class:`pina.label_tensor.LabelTensor`.
|
||||
:class:`~pina.label_tensor.LabelTensor`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
@@ -177,7 +177,7 @@ class Collator:
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
:class:`pina.data.dataset.PinaGraphDataset`.
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`.
|
||||
|
||||
:param data_list: Elememts to be collated.
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
@@ -185,7 +185,7 @@ class Collator:
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a
|
||||
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
|
||||
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
|
||||
@@ -15,26 +15,26 @@ class PinaDatasetFactory:
|
||||
Depending on the type inside the conditions, it creates a different dataset
|
||||
object:
|
||||
|
||||
- :class:`pina.data.dataset.PinaTensorDataset` for handling
|
||||
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
|
||||
- :class:`pina.data.dataset.PinaGraphDataset` for handling
|
||||
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
|
||||
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
||||
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
|
||||
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
"""
|
||||
|
||||
def __new__(cls, conditions_dict, **kwargs):
|
||||
"""
|
||||
Instantiate the appropriate subclass of
|
||||
:class:`pina.data.dataset.PinaDataset`.
|
||||
:class:`~pina.data.dataset.PinaDataset`.
|
||||
|
||||
If a graph is present in the conditions, returns a
|
||||
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing all the conditions
|
||||
to be included in the dataset instance.
|
||||
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`pina.data.dataset.PinaGraphDataset`
|
||||
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
|
||||
class PinaDataset(Dataset, ABC):
|
||||
"""
|
||||
Abstract class for the PINA dataset. It defines the common interface for
|
||||
the :class:`pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`pina.data.dataset.PinaGraphDataset` classes.
|
||||
the :class:`~pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||
):
|
||||
"""
|
||||
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
|
||||
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
|
||||
the providedconditions dictionary, the maximum number of conditions to
|
||||
consider, and the automatic batching flag.
|
||||
|
||||
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
|
||||
points to include in a single batch for each condition.
|
||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||
batching is enabled in
|
||||
:class:`pina.data.data_module.PinaDataModule`.
|
||||
:class:`~pina.data.data_module.PinaDataModule`.
|
||||
"""
|
||||
|
||||
# Store the conditions dictionary
|
||||
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||
:class:`pina.label_tensor.LabelTensor` data.
|
||||
:class:`~pina.label_tensor.LabelTensor` data.
|
||||
"""
|
||||
|
||||
# Override _retrive_data method for torch.Tensor data
|
||||
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
|
||||
:param dict data: Dictionary containing the data
|
||||
(only :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor`).
|
||||
:class:`~pina.label_tensor.LabelTensor`).
|
||||
:param list[int] idx_list: indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
||||
and :class:`pina.graph.Graph` data.
|
||||
and :class:`~pina.graph.Graph` data.
|
||||
"""
|
||||
|
||||
def _create_graph_batch(self, data):
|
||||
|
||||
@@ -20,13 +20,13 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create a new instance of the :class:`pina.graph.Graph` class by checking
|
||||
Create a new instance of the :class:`~pina.graph.Graph` class by checking
|
||||
the consistency of the input data and storing the attributes.
|
||||
|
||||
:param kwargs: Parameters used to initialize the
|
||||
:class:`pina.graph.Graph` object.
|
||||
:class:`~pina.graph.Graph` object.
|
||||
:type kwargs: dict
|
||||
:return: A new instance of the :class:`pina.graph.Graph` class.
|
||||
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
@@ -230,7 +230,7 @@ class GraphBuilder:
|
||||
):
|
||||
"""
|
||||
Compute the edge attributes and create a new instance of the
|
||||
:class:`pina.graph.Graph` class.
|
||||
:class:`~pina.graph.Graph` class.
|
||||
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
@@ -248,8 +248,8 @@ class GraphBuilder:
|
||||
If provided, overrides `edge_attr`.
|
||||
:type custom_edge_func: callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:class:`pina.graph.Graph` class constructor.
|
||||
:return: A :class:`pina.graph.Graph` instance constructed using the
|
||||
:class:`~pina.graph.Graph` class constructor.
|
||||
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||
provided information.
|
||||
:rtype: Graph
|
||||
"""
|
||||
@@ -289,7 +289,7 @@ class RadiusGraph(GraphBuilder):
|
||||
|
||||
def __new__(cls, pos, radius, **kwargs):
|
||||
"""
|
||||
Extends the :class:`pina.graph.GraphBuilder` class to compute
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
|
||||
@@ -299,9 +299,9 @@ class RadiusGraph(GraphBuilder):
|
||||
:param radius: The radius within which points are connected.
|
||||
:type radius: float
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
|
||||
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
|
||||
constructors.
|
||||
:return: A :class:`pina.graph.Graph` instance containing the input
|
||||
:return: A :class:`~pina.graph.Graph` instance containing the input
|
||||
information and the computed edge_index.
|
||||
:rtype: Graph
|
||||
"""
|
||||
@@ -339,7 +339,7 @@ class KNNGraph(GraphBuilder):
|
||||
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
"""
|
||||
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index
|
||||
based on a K-nearest neighbors algorithm.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
@@ -351,7 +351,7 @@ class KNNGraph(GraphBuilder):
|
||||
The additional keyword arguments to be passed to GraphBuilder
|
||||
and Graph classes
|
||||
|
||||
:return: A :class:`pina.graph.Graph` instance containg the
|
||||
:return: A :class:`~pina.graph.Graph` instance containg the
|
||||
information passed in input and the computed edge_index
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
@@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, x, labels, *args, **kwargs):
|
||||
"""
|
||||
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
|
||||
Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
|
||||
class.
|
||||
|
||||
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
|
||||
:class:`pina.label_tensor.LabelTensor`.
|
||||
:class:`~pina.label_tensor.LabelTensor`.
|
||||
:param labels: Labels to assign to the tensor.
|
||||
:type labels: str | list[str] | dict
|
||||
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
|
||||
:return: The instance of the :class:`~pina.label_tensor.LabelTensor`
|
||||
class.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
|
||||
@property
|
||||
def tensor(self):
|
||||
"""
|
||||
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
|
||||
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
|
||||
object.
|
||||
|
||||
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
|
||||
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
@@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
def __init__(self, x, labels):
|
||||
"""
|
||||
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
|
||||
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
|
||||
the labels and a :class:`torch.Tensor`. Internally, the initialization
|
||||
method will store check the compatibility of the labels with the tensor
|
||||
shape.
|
||||
@@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
The string representation of the :class:`pina.label_tensor.LabelTensor`.
|
||||
The string representation of the :class:`~pina.label_tensor.LabelTensor`.
|
||||
|
||||
:return: String representation of the
|
||||
:class:`pina.label_tensor.LabelTensor` instance.
|
||||
:class:`~pina.label_tensor.LabelTensor` instance.
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
@@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor):
|
||||
Concatenate a list of tensors along a specified dimension. For more
|
||||
details, see :meth:`torch.cat`.
|
||||
|
||||
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
|
||||
:param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor`
|
||||
instances to concatenate
|
||||
:param int dim: dimensions on which you want to perform the operation
|
||||
(default is 0)
|
||||
@@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
:param list[LabelTensor] tensors: A list of tensors to stack.
|
||||
All tensors must have the same shape.
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||
by stacking the input tensors, with the updated labels.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor):
|
||||
:param bool mode: A boolean value indicating whether the tensor should
|
||||
track gradients.If `True`, the tensor will track gradients;
|
||||
if `False`, it will not.
|
||||
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
|
||||
:return: The :class:`~pina.label_tensor.LabelTensor` itself with the
|
||||
updated `requires_grad` state and retained labels.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor):
|
||||
Performs Tensor dtype and/or device conversion. For more details, see
|
||||
:meth:`torch.Tensor.to`.
|
||||
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||
updated dtype and/or device and retained labels.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
def clone(self, *args, **kwargs):
|
||||
"""
|
||||
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
|
||||
Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
|
||||
:meth:`torch.Tensor.clone`.
|
||||
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||
same data and labels but allocated in a different memory location.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor):
|
||||
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
||||
|
||||
:param list of LabelTensor label_tensors: The
|
||||
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
|
||||
:class:`~pina.label_tensor.LabelTensor` instances to stack. They need
|
||||
to have equal labels.
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||
by stacking the input tensors vertically.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor):
|
||||
def __getitem__(self, index):
|
||||
""" "
|
||||
Override the __getitem__ method to handle the labels of the
|
||||
:class:`pina.label_tensor.LabelTensor` instance. It first performs
|
||||
:class:`~pina.label_tensor.LabelTensor` instance. It first performs
|
||||
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
|
||||
then updates the labels based on the index.
|
||||
|
||||
:param index: The index used to access the item
|
||||
:type index: int | str | tuple of int | list ot int | torch.Tensor
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||
`__getitem__` operation on :class:`torch.Tensor` part of the
|
||||
instance, with the updated labels.
|
||||
:rtype: LabelTensor
|
||||
@@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor):
|
||||
def summation(tensors):
|
||||
"""
|
||||
Computes the summation of a list of
|
||||
:class:`pina.label_tensor.LabelTensor` instances.
|
||||
:class:`~pina.label_tensor.LabelTensor` instances.
|
||||
|
||||
|
||||
:param list[LabelTensor] tensors: A list of tensors to sum. All
|
||||
@@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor):
|
||||
For more details, see :meth:`torch.Tensor.reshape`.
|
||||
|
||||
:param tuple of int shape: The new shape of the tensor.
|
||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||
updated shape and labels.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user