Fix doc
This commit is contained in:
@@ -32,8 +32,8 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Set the problem to which the condition is associated.
|
Set the problem to which the condition is associated.
|
||||||
|
|
||||||
:param value: Problem to which the condition is associated.
|
:param pina.problem.AbstractProblem value: Problem to which the
|
||||||
:type value: pina.problem.AbstractProblem
|
condition is associated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._problem = value
|
self._problem = value
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ..graph import Graph
|
|||||||
class DataCondition(ConditionInterface):
|
class DataCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
This condition must be used every time a Unsupervised Loss is needed in
|
This condition must be used every time a Unsupervised Loss is needed in
|
||||||
the Solver. The conditionalvariable can be passed as extra-input when
|
the Solver. The `conditional_variable` can be passed as extra-input when
|
||||||
the model learns a conditional distribution.
|
the model learns a conditional distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -31,7 +31,8 @@ class DataCondition(ConditionInterface):
|
|||||||
:param conditional_variables: Conditional variables for the condition.
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
:type conditional_variables: torch.Tensor | LabelTensor
|
:type conditional_variables: torch.Tensor | LabelTensor
|
||||||
:return: Subclass of DataCondition.
|
:return: Subclass of DataCondition.
|
||||||
:rtype: TensorDataCondition | GraphDataCondition
|
:rtype: pina.condition.data_condition.TensorDataCondition |
|
||||||
|
pina.condition.data_condition.GraphDataCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,
|
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function.
|
equation function.
|
||||||
:return: Subclass of InputEquationCondition, based on the input type.
|
:return: Subclass of InputEquationCondition, based on the input type.
|
||||||
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
|
:rtype: pina.condition.input_equation_condition.
|
||||||
|
InputTensorEquationCondition |
|
||||||
|
pina.condition.input_equation_condition.InputGraphEquationCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type
|
:raises ValueError: If input is not of type
|
||||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
|
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
|
||||||
@@ -105,7 +107,7 @@ class InputGraphEquationCondition(InputEquationCondition):
|
|||||||
in the :class:`pina.graph.Graph` object.
|
in the :class:`pina.graph.Graph` object.
|
||||||
|
|
||||||
:param input: Input data.
|
:param input: Input data.
|
||||||
:type input: torch.Tensor | Graph | torch_geometric.data.Data
|
:type input: torch.Tensor | Graph | Data
|
||||||
|
|
||||||
:raises ValueError: If the input data object does not contain at least
|
:raises ValueError: If the input data object does not contain at least
|
||||||
one LabelTensor.
|
one LabelTensor.
|
||||||
|
|||||||
@@ -25,23 +25,26 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
the types of input and target data.
|
the types of input and target data.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | LabelTensor | Graph |
|
:type input: torch.Tensor | LabelTensor | Graph |
|
||||||
torch_geometric.data.Data | list[Graph] |
|
torch_geometric.data.Data | list[Graph] |
|
||||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
list[torch_geometric.data.Data] | tuple[Graph] |
|
||||||
tuple[torch_geometric.data.Data]
|
tuple[torch_geometric.data.Data]
|
||||||
:param target: Target data for the condition.
|
:param target: Target data for the condition.
|
||||||
:type target: torch.Tensor | LabelTensor | Graph |
|
:type target: torch.Tensor | LabelTensor | Graph |
|
||||||
torch_geometric.data.Data | list[Graph] |
|
torch_geometric.data.Data | list[Graph] |
|
||||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
list[torch_geometric.data.Data] | tuple[Graph] |
|
||||||
tuple[torch_geometric.data.Data]
|
tuple[torch_geometric.data.Data]
|
||||||
:return: Subclass of InputTargetCondition
|
:return: Subclass of InputTargetCondition
|
||||||
:rtype: TensorInputTensorTargetCondition | \
|
:rtype: pina.condition.input_target_condition.
|
||||||
TensorInputGraphTargetCondition | \
|
TensorInputTensorTargetCondition |
|
||||||
GraphInputTensorTargetCondition | \
|
pina.condition.input_target_condition.
|
||||||
GraphInputGraphTargetCondition
|
TensorInputGraphTargetCondition |
|
||||||
|
pina.condition.input_target_condition.
|
||||||
|
GraphInputTensorTargetCondition |
|
||||||
|
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||||
|
|
||||||
:raises ValueError: If input and or target are not of type
|
:raises ValueError: If input and or target are not of type
|
||||||
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
|
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
|
||||||
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||||
"""
|
"""
|
||||||
if cls != InputTargetCondition:
|
if cls != InputTargetCondition:
|
||||||
|
|||||||
114
pina/graph.py
114
pina/graph.py
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
This module provides an interface to build torch_geometric.data.Data objects.
|
Module to build Graph objects and perform operations on them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,7 +11,8 @@ from .utils import check_consistency, is_function
|
|||||||
|
|
||||||
class Graph(Data):
|
class Graph(Data):
|
||||||
"""
|
"""
|
||||||
A class to build torch_geometric.data.Data objects.
|
Extends :class:`~torch_geometric.data.Data` class to include additional
|
||||||
|
checks and functionlities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(
|
def __new__(
|
||||||
@@ -19,13 +20,16 @@ class Graph(Data):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Instantiates a new instance of the Graph class, performing type
|
Create a new instance of the :class:`pina.graph.Graph` class by checking
|
||||||
consistency checks.
|
the consistency of the input data and storing the attributes.
|
||||||
|
|
||||||
:param kwargs: Parameters to construct the Graph object.
|
:param kwargs: Parameters used to initialize the
|
||||||
:return: A new instance of the Graph class.
|
:class:`pina.graph.Graph` object.
|
||||||
|
:type kwargs: dict
|
||||||
|
:return: A new instance of the :class:`pina.graph.Graph` class.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# create class instance
|
# create class instance
|
||||||
instance = Data.__new__(cls)
|
instance = Data.__new__(cls)
|
||||||
|
|
||||||
@@ -45,27 +49,28 @@ class Graph(Data):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Graph object by setting the node features, edge index,
|
Initialize the object by setting the node features, edge index,
|
||||||
edge attributes, and positions. The edge index is preprocessed to make
|
edge attributes, and positions. The edge index is preprocessed to make
|
||||||
the graph undirected if required. For more details, see the
|
the graph undirected if required. For more details, see the
|
||||||
:meth: `torch_geometric.data.Data`
|
:meth:`torch_geometric.data.Data`
|
||||||
|
|
||||||
:param x: Optional tensor of node features (N, F) where F is the number
|
:param x: Optional tensor of node features `(N, F)` where `F` is the
|
||||||
of features per node.
|
number of features per node.
|
||||||
:type x: torch.Tensor, LabelTensor
|
:type x: torch.Tensor, LabelTensor
|
||||||
:param torch.Tensor edge_index: A tensor of shape (2, E) representing
|
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
|
||||||
the indices of the graph's edges.
|
the indices of the graph's edges.
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||||
points in D-dimensional space.
|
points in `D`-dimensional space.
|
||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor | LabelTensor
|
||||||
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is
|
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
|
||||||
the number of edge features
|
is the number of edge features
|
||||||
:param bool undirected: Whether to make the graph undirected
|
:param bool undirected: Whether to make the graph undirected
|
||||||
:param kwargs: Additional keyword arguments passed to the
|
:param kwargs: Additional keyword arguments passed to the
|
||||||
`torch_geometric.data.Data` class constructor. If the argument
|
`torch_geometric.data.Data` class constructor. If the argument
|
||||||
is a `torch.Tensor` or `LabelTensor`, it is included in the Data
|
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
|
||||||
object as a graph parameter.
|
object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# preprocessing
|
# preprocessing
|
||||||
self._preprocess_edge_index(edge_index, undirected)
|
self._preprocess_edge_index(edge_index, undirected)
|
||||||
|
|
||||||
@@ -107,6 +112,8 @@ class Graph(Data):
|
|||||||
Check if the position tensor is consistent.
|
Check if the position tensor is consistent.
|
||||||
|
|
||||||
:param torch.Tensor pos: The position tensor.
|
:param torch.Tensor pos: The position tensor.
|
||||||
|
|
||||||
|
:raises ValueError: If the position tensor is not consistent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if pos is not None:
|
if pos is not None:
|
||||||
@@ -120,6 +127,8 @@ class Graph(Data):
|
|||||||
Check if the edge index is consistent.
|
Check if the edge index is consistent.
|
||||||
|
|
||||||
:param torch.Tensor edge_index: The edge index tensor.
|
:param torch.Tensor edge_index: The edge index tensor.
|
||||||
|
|
||||||
|
:raises ValueError: If the edge index tensor is not consistent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||||
@@ -136,6 +145,9 @@ class Graph(Data):
|
|||||||
|
|
||||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||||
:param torch.Tensor edge_index: The edge index tensor.
|
:param torch.Tensor edge_index: The edge index tensor.
|
||||||
|
|
||||||
|
:raises ValueError: If the edge attribute tensor is not consistent.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if edge_attr is not None:
|
if edge_attr is not None:
|
||||||
@@ -217,27 +229,28 @@ class GraphBuilder:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute the edge attributes and create a new instance of the Graph
|
Compute the edge attributes and create a new instance of the
|
||||||
class.
|
:class:`pina.graph.Graph` class.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||||
points in D-dimensional space.
|
points in `D`-dimensional space.
|
||||||
:type pos: torch.Tensor or LabelTensor
|
:type pos: torch.Tensor or LabelTensor
|
||||||
:param edge_index: A tensor of shape (2, E) representing the indices of
|
:param edge_index: A tensor of shape `(2, E)` representing the indices
|
||||||
the graph's edges.
|
of the graph's edges.
|
||||||
:type edge_index: torch.Tensor
|
:type edge_index: torch.Tensor
|
||||||
:param x: Optional tensor of node features of shape (N, F), where F is
|
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
|
||||||
the number of features per node.
|
is the number of features per node.
|
||||||
:type x: torch.Tensor | LabelTensor, optional
|
:type x: torch.Tensor | LabelTensor, optional
|
||||||
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
|
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
|
||||||
where F is the number of features per edge.
|
where `F` is the number of features per edge.
|
||||||
:type edge_attr: torch.Tensor, optional
|
:type edge_attr: torch.Tensor, optional
|
||||||
:param custom_edge_func: A custom function to compute edge attributes.
|
:param custom_edge_func: A custom function to compute edge attributes.
|
||||||
If provided, overrides `edge_attr`.
|
If provided, overrides `edge_attr`.
|
||||||
:type custom_edge_func: callable, optional
|
:type custom_edge_func: callable, optional
|
||||||
:param kwargs: Additional keyword arguments passed to the Graph class
|
:param kwargs: Additional keyword arguments passed to the
|
||||||
constructor.
|
:class:`pina.graph.Graph` class constructor.
|
||||||
:return: A Graph instance constructed using the provided information.
|
:return: A :class:`pina.graph.Graph` instance constructed using the
|
||||||
|
provided information.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
edge_attr = cls._create_edge_attr(
|
edge_attr = cls._create_edge_attr(
|
||||||
@@ -271,42 +284,46 @@ class GraphBuilder:
|
|||||||
|
|
||||||
class RadiusGraph(GraphBuilder):
|
class RadiusGraph(GraphBuilder):
|
||||||
"""
|
"""
|
||||||
A class to build a radius graph.
|
A class to build a graph based on a radius.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, pos, radius, **kwargs):
|
def __new__(cls, pos, radius, **kwargs):
|
||||||
"""
|
"""
|
||||||
Extends the `GraphBuilder` class to compute edge_index based on a
|
Extends the :class:`pina.graph.GraphBuilder` class to compute
|
||||||
radius. Each point is connected to all the points within the radius.
|
edge_index based on a radius. Each point is connected to all the points
|
||||||
|
within the radius.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||||
points in D-dimensional space.
|
points in `D`-dimensional space.
|
||||||
:type pos: torch.Tensor or LabelTensor
|
:type pos: torch.Tensor or LabelTensor
|
||||||
:param radius: The radius within which points are connected.
|
:param radius: The radius within which points are connected.
|
||||||
:type radius: float
|
:type radius: float
|
||||||
:param kwargs: Additional keyword arguments to be passed to the
|
:param kwargs: Additional keyword arguments to be passed to the
|
||||||
`GraphBuilder` and `Graph` constructors.
|
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
|
||||||
:return: A `Graph` instance containing the input information and the
|
constructors.
|
||||||
computed edge_index.
|
:return: A :class:`pina.graph.Graph` instance containing the input
|
||||||
|
information and the computed edge_index.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edge_index = cls.compute_radius_graph(pos, radius)
|
edge_index = cls.compute_radius_graph(pos, radius)
|
||||||
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_radius_graph(points, radius):
|
def compute_radius_graph(points, radius):
|
||||||
"""
|
"""
|
||||||
Computes edge_index for a given set of points base on the radius.
|
Computes `edge_index` for a given set of points base on the radius.
|
||||||
Each point is connected to all the points within the radius.
|
Each point is connected to all the points within the radius.
|
||||||
|
|
||||||
:param points: A tensor of shape (N, D) representing the positions of
|
:param points: A tensor of shape `(N, D)` representing the positions of
|
||||||
N points in D-dimensional space.
|
N points in D-dimensional space.
|
||||||
:type points: torch.Tensor | LabelTensor
|
:type points: torch.Tensor | LabelTensor
|
||||||
:param float radius: The number of nearest neighbors to find for each
|
:param float radius: The number of nearest neighbors to find for each
|
||||||
point.
|
point.
|
||||||
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
|
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
|
||||||
edges, representing the edge indices of the KNN graph.
|
of edges, representing the edge indices of the KNN graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dist = torch.cdist(points, points, p=2)
|
dist = torch.cdist(points, points, p=2)
|
||||||
return (
|
return (
|
||||||
torch.nonzero(dist <= radius, as_tuple=False)
|
torch.nonzero(dist <= radius, as_tuple=False)
|
||||||
@@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder):
|
|||||||
|
|
||||||
class KNNGraph(GraphBuilder):
|
class KNNGraph(GraphBuilder):
|
||||||
"""
|
"""
|
||||||
A class to build a KNN graph.
|
A class to build a K-nearest neighbors graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, pos, neighbours, **kwargs):
|
def __new__(cls, pos, neighbours, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a new instance of the Graph class using k-nearest neighbors
|
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
|
||||||
algorithm to define the edges.
|
based on a K-nearest neighbors algorithm.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||||
points in D-dimensional space.
|
points in D-dimensional space.
|
||||||
@@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder):
|
|||||||
The additional keyword arguments to be passed to GraphBuilder
|
The additional keyword arguments to be passed to GraphBuilder
|
||||||
and Graph classes
|
and Graph classes
|
||||||
|
|
||||||
:return: Graph instance containg the information passed in input and
|
:return: A :class:`pina.graph.Graph` instance containg the
|
||||||
the computed edge_index
|
information passed in input and the computed edge_index
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -371,7 +388,8 @@ class LabelBatch(Batch):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_data_list(cls, data_list):
|
def from_data_list(cls, data_list):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of Data objects.
|
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||||
|
objects.
|
||||||
|
|
||||||
:param data_list: List of Data/Graph objects
|
:param data_list: List of Data/Graph objects
|
||||||
:type data_list: list[Data] | list[Graph]
|
:type data_list: list[Data] | list[Graph]
|
||||||
|
|||||||
@@ -14,13 +14,15 @@ class LabelTensor(torch.Tensor):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def __new__(cls, x, labels, *args, **kwargs):
|
def __new__(cls, x, labels, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a new instance of the :class:`LabelTensor` class.
|
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
|
||||||
|
class.
|
||||||
|
|
||||||
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
|
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
|
||||||
:class:`LabelTensor`.
|
:class:`pina.label_tensor.LabelTensor`.
|
||||||
:param labels: Labels to assign to the tensor.
|
:param labels: Labels to assign to the tensor.
|
||||||
:type labels: str | list[str] | dict
|
:type labels: str | list[str] | dict
|
||||||
:return: The instance of the :class:`LabelTensor` class.
|
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
|
||||||
|
class.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -31,9 +33,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def tensor(self):
|
def tensor(self):
|
||||||
"""
|
"""
|
||||||
Give the tensor part of the :class:`LabelTensor` object.
|
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
|
||||||
|
object.
|
||||||
|
|
||||||
:return: tensor part of the :class:`LabelTensor`.
|
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -41,9 +44,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def __init__(self, x, labels):
|
def __init__(self, x, labels):
|
||||||
"""
|
"""
|
||||||
Construct a :class:`LabelTensor` by passing a dict of the labels and a
|
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
|
||||||
:class:`torch.Tensor`. Internally, the initialization method will store
|
the labels and a :class:`torch.Tensor`. Internally, the initialization
|
||||||
check the compatibility of the labels with the tensor shape.
|
method will store check the compatibility of the labels with the tensor
|
||||||
|
shape.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> from pina import LabelTensor
|
>>> from pina import LabelTensor
|
||||||
@@ -271,9 +275,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""
|
"""
|
||||||
The string representation of the :class:`LabelTensor`.
|
The string representation of the :class:`pina.label_tensor.LabelTensor`.
|
||||||
|
|
||||||
:return: String representation of the :class:`LabelTensor` instance.
|
:return: String representation of the
|
||||||
|
:class:`pina.label_tensor.LabelTensor` instance.
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -290,8 +295,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
Concatenate a list of tensors along a specified dimension. For more
|
Concatenate a list of tensors along a specified dimension. For more
|
||||||
details, see :meth:`torch.cat`.
|
details, see :meth:`torch.cat`.
|
||||||
|
|
||||||
:param list[LabelTensor] tensors: :class:`LabelTensor` instances to
|
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
|
||||||
concatenate
|
instances to concatenate
|
||||||
:param int dim: dimensions on which you want to perform the operation
|
:param int dim: dimensions on which you want to perform the operation
|
||||||
(default is 0)
|
(default is 0)
|
||||||
:return: A new :class:`LabelTensor' instance obtained by concatenating
|
:return: A new :class:`LabelTensor' instance obtained by concatenating
|
||||||
@@ -346,8 +351,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
:param list[LabelTensor] tensors: A list of tensors to stack.
|
:param list[LabelTensor] tensors: A list of tensors to stack.
|
||||||
All tensors must have the same shape.
|
All tensors must have the same shape.
|
||||||
:return: A new :class:`LabelTensor` instance obtained by stacking the
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||||
input tensors, with the updated labels.
|
by stacking the input tensors, with the updated labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -368,8 +373,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
:param bool mode: A boolean value indicating whether the tensor should
|
:param bool mode: A boolean value indicating whether the tensor should
|
||||||
track gradients.If `True`, the tensor will track gradients;
|
track gradients.If `True`, the tensor will track gradients;
|
||||||
if `False`, it will not.
|
if `False`, it will not.
|
||||||
:return: The :class:`LabelTensor` itself with the updated
|
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
|
||||||
`requires_grad` state and retained labels.
|
updated `requires_grad` state and retained labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -394,8 +399,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
Performs Tensor dtype and/or device conversion. For more details, see
|
Performs Tensor dtype and/or device conversion. For more details, see
|
||||||
:meth:`torch.Tensor.to`.
|
:meth:`torch.Tensor.to`.
|
||||||
|
|
||||||
:return: A new :class:`LabelTensor` instance with the updated dtype
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||||
and/or device and retained labels.
|
updated dtype and/or device and retained labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -405,11 +410,11 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def clone(self, *args, **kwargs):
|
def clone(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Clone the :class:`LabelTensor`. For more details, see
|
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
|
||||||
:meth:`torch.Tensor.clone`.
|
:meth:`torch.Tensor.clone`.
|
||||||
|
|
||||||
:return: A new :class:`LabelTensor` instance with the same data and
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||||
labels but allocated in a different memory location.
|
same data and labels but allocated in a different memory location.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -466,10 +471,11 @@ class LabelTensor(torch.Tensor):
|
|||||||
"""
|
"""
|
||||||
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
||||||
|
|
||||||
:param list of LabelTensor label_tensors: The :class:`LabelTensor`
|
:param list of LabelTensor label_tensors: The
|
||||||
instances to stack. They need to have equal labels.
|
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
|
||||||
:return: A new :class:`LabelTensor` instance obtained by stacking the
|
to have equal labels.
|
||||||
input tensors vertically.
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||||
|
by stacking the input tensors vertically.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -524,15 +530,15 @@ class LabelTensor(torch.Tensor):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
""" "
|
""" "
|
||||||
Override the __getitem__ method to handle the labels of the
|
Override the __getitem__ method to handle the labels of the
|
||||||
:class:`LabelTensor` instance. It first performs __getitem__ operation
|
:class:`pina.label_tensor.LabelTensor` instance. It first performs
|
||||||
on the :class:`torch.Tensor` part of the instance, then updates the
|
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
|
||||||
labels based on the index.
|
then updates the labels based on the index.
|
||||||
|
|
||||||
:param index: The index used to access the item
|
:param index: The index used to access the item
|
||||||
:type index: int | str | tuple of int | list ot int | torch.Tensor
|
:type index: int | str | tuple of int | list ot int | torch.Tensor
|
||||||
:return: A new :class:`LabelTensor` instance obtained __getitem__
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
||||||
operation on :class:`torch.Tensor` part of the instance, with the
|
`__getitem__` operation on :class:`torch.Tensor` part of the
|
||||||
updated labels.
|
instance, with the updated labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
|
|
||||||
:raises KeyError: If an invalid label index is provided.
|
:raises KeyError: If an invalid label index is provided.
|
||||||
@@ -665,7 +671,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def summation(tensors):
|
def summation(tensors):
|
||||||
"""
|
"""
|
||||||
Computes the summation of a list of :class:`LabelTensor` instances.
|
Computes the summation of a list of
|
||||||
|
:class:`pina.label_tensor.LabelTensor` instances.
|
||||||
|
|
||||||
|
|
||||||
:param list[LabelTensor] tensors: A list of tensors to sum. All
|
:param list[LabelTensor] tensors: A list of tensors to sum. All
|
||||||
@@ -712,8 +719,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
For more details, see :meth:`torch.Tensor.reshape`.
|
For more details, see :meth:`torch.Tensor.reshape`.
|
||||||
|
|
||||||
:param tuple of int shape: The new shape of the tensor.
|
:param tuple of int shape: The new shape of the tensor.
|
||||||
:return: A new :class:`LabelTensor` instance with the updated shape and
|
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
||||||
labels.
|
updated shape and labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user