This commit is contained in:
FilippoOlivo
2025-03-12 12:03:47 +01:00
committed by Nicola Demo
parent 59e6ee595c
commit ae796ce34c
6 changed files with 128 additions and 97 deletions

View File

@@ -32,8 +32,8 @@ class ConditionInterface(metaclass=ABCMeta):
"""
Set the problem to which the condition is associated.
:param value: Problem to which the condition is associated.
:type value: pina.problem.AbstractProblem
:param pina.problem.AbstractProblem value: Problem to which the
condition is associated.
"""
self._problem = value

View File

@@ -12,7 +12,7 @@ from ..graph import Graph
class DataCondition(ConditionInterface):
"""
This condition must be used every time a Unsupervised Loss is needed in
the Solver. The conditionalvariable can be passed as extra-input when
the Solver. The `conditional_variable` can be passed as extra-input when
the model learns a conditional distribution.
"""
@@ -31,7 +31,8 @@ class DataCondition(ConditionInterface):
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,

View File

@@ -30,7 +30,9 @@ class InputEquationCondition(ConditionInterface):
:param EquationInterface equation: Equation object containing the
equation function.
:return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
:rtype: pina.condition.input_equation_condition.
InputTensorEquationCondition |
pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
@@ -105,7 +107,7 @@ class InputGraphEquationCondition(InputEquationCondition):
in the :class:`pina.graph.Graph` object.
:param input: Input data.
:type input: torch.Tensor | Graph | torch_geometric.data.Data
:type input: torch.Tensor | Graph | Data
:raises ValueError: If the input data object does not contain at least
one LabelTensor.

View File

@@ -35,10 +35,13 @@ class InputTargetCondition(ConditionInterface):
list[torch_geometric.data.Data] | tuple[Graph] |
tuple[torch_geometric.data.Data]
:return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition | \
TensorInputGraphTargetCondition | \
GraphInputTensorTargetCondition | \
GraphInputGraphTargetCondition
:rtype: pina.condition.input_target_condition.
TensorInputTensorTargetCondition |
pina.condition.input_target_condition.
TensorInputGraphTargetCondition |
pina.condition.input_target_condition.
GraphInputTensorTargetCondition |
pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,

View File

@@ -1,5 +1,5 @@
"""
This module provides an interface to build torch_geometric.data.Data objects.
Module to build Graph objects and perform operations on them.
"""
import torch
@@ -11,7 +11,8 @@ from .utils import check_consistency, is_function
class Graph(Data):
"""
A class to build torch_geometric.data.Data objects.
Extends :class:`~torch_geometric.data.Data` class to include additional
checks and functionlities.
"""
def __new__(
@@ -19,13 +20,16 @@ class Graph(Data):
**kwargs,
):
"""
Instantiates a new instance of the Graph class, performing type
consistency checks.
Create a new instance of the :class:`pina.graph.Graph` class by checking
the consistency of the input data and storing the attributes.
:param kwargs: Parameters to construct the Graph object.
:return: A new instance of the Graph class.
:param kwargs: Parameters used to initialize the
:class:`pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`pina.graph.Graph` class.
:rtype: Graph
"""
# create class instance
instance = Data.__new__(cls)
@@ -45,27 +49,28 @@ class Graph(Data):
**kwargs,
):
"""
Initialize the Graph object by setting the node features, edge index,
Initialize the object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the
:meth: `torch_geometric.data.Data`
:meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
:param x: Optional tensor of node features `(N, F)` where `F` is the
number of features per node.
:type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape (2, E) representing
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
the indices of the graph's edges.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is
the number of edge features
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
is the number of edge features
:param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the
`torch_geometric.data.Data` class constructor. If the argument
is a `torch.Tensor` or `LabelTensor`, it is included in the Data
object as a graph parameter.
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
object.
"""
# preprocessing
self._preprocess_edge_index(edge_index, undirected)
@@ -107,6 +112,8 @@ class Graph(Data):
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent.
"""
if pos is not None:
@@ -120,6 +127,8 @@ class Graph(Data):
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
@@ -136,6 +145,9 @@ class Graph(Data):
:param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
"""
if edge_attr is not None:
@@ -217,27 +229,28 @@ class GraphBuilder:
**kwargs,
):
"""
Compute the edge attributes and create a new instance of the Graph
class.
Compute the edge attributes and create a new instance of the
:class:`pina.graph.Graph` class.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape (2, E) representing the indices of
the graph's edges.
:param edge_index: A tensor of shape `(2, E)` representing the indices
of the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape (N, F), where F is
the number of features per node.
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
where F is the number of features per edge.
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
where `F` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the Graph class
constructor.
:return: A Graph instance constructed using the provided information.
:param kwargs: Additional keyword arguments passed to the
:class:`pina.graph.Graph` class constructor.
:return: A :class:`pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph
"""
edge_attr = cls._create_edge_attr(
@@ -271,42 +284,46 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder):
"""
A class to build a radius graph.
A class to build a graph based on a radius.
"""
def __new__(cls, pos, radius, **kwargs):
"""
Extends the `GraphBuilder` class to compute edge_index based on a
radius. Each point is connected to all the points within the radius.
Extends the :class:`pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the
`GraphBuilder` and `Graph` constructors.
:return: A `Graph` instance containing the input information and the
computed edge_index.
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
constructors.
:return: A :class:`pina.graph.Graph` instance containing the input
information and the computed edge_index.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes edge_index for a given set of points base on the radius.
Computes `edge_index` for a given set of points base on the radius.
Each point is connected to all the points within the radius.
:param points: A tensor of shape (N, D) representing the positions of
:param points: A tensor of shape `(N, D)` representing the positions of
N points in D-dimensional space.
:type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each
point.
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
edges, representing the edge indices of the KNN graph.
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
of edges, representing the edge indices of the KNN graph.
"""
dist = torch.cdist(points, points, p=2)
return (
torch.nonzero(dist <= radius, as_tuple=False)
@@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder):
"""
A class to build a KNN graph.
A class to build a K-nearest neighbors graph.
"""
def __new__(cls, pos, neighbours, **kwargs):
"""
Creates a new instance of the Graph class using k-nearest neighbors
algorithm to define the edges.
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
@@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder):
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:return: Graph instance containg the information passed in input and
the computed edge_index
:return: A :class:`pina.graph.Graph` instance containg the
information passed in input and the computed edge_index
:rtype: Graph
"""
@@ -371,7 +388,8 @@ class LabelBatch(Batch):
@classmethod
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of Data objects.
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
:param data_list: List of Data/Graph objects
:type data_list: list[Data] | list[Graph]

View File

@@ -14,13 +14,15 @@ class LabelTensor(torch.Tensor):
@staticmethod
def __new__(cls, x, labels, *args, **kwargs):
"""
Create a new instance of the :class:`LabelTensor` class.
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`LabelTensor`.
:class:`pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict
:return: The instance of the :class:`LabelTensor` class.
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
class.
:rtype: LabelTensor
"""
@@ -31,9 +33,10 @@ class LabelTensor(torch.Tensor):
@property
def tensor(self):
"""
Give the tensor part of the :class:`LabelTensor` object.
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
object.
:return: tensor part of the :class:`LabelTensor`.
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor
"""
@@ -41,9 +44,10 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels):
"""
Construct a :class:`LabelTensor` by passing a dict of the labels and a
:class:`torch.Tensor`. Internally, the initialization method will store
check the compatibility of the labels with the tensor shape.
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
the labels and a :class:`torch.Tensor`. Internally, the initialization
method will store check the compatibility of the labels with the tensor
shape.
:Example:
>>> from pina import LabelTensor
@@ -271,9 +275,10 @@ class LabelTensor(torch.Tensor):
def __str__(self):
"""
The string representation of the :class:`LabelTensor`.
The string representation of the :class:`pina.label_tensor.LabelTensor`.
:return: String representation of the :class:`LabelTensor` instance.
:return: String representation of the
:class:`pina.label_tensor.LabelTensor` instance.
:rtype: str
"""
@@ -290,8 +295,8 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`.
:param list[LabelTensor] tensors: :class:`LabelTensor` instances to
concatenate
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
instances to concatenate
:param int dim: dimensions on which you want to perform the operation
(default is 0)
:return: A new :class:`LabelTensor' instance obtained by concatenating
@@ -346,8 +351,8 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape.
:return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors, with the updated labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors, with the updated labels.
:rtype: LabelTensor
"""
@@ -368,8 +373,8 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients;
if `False`, it will not.
:return: The :class:`LabelTensor` itself with the updated
`requires_grad` state and retained labels.
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
updated `requires_grad` state and retained labels.
:rtype: LabelTensor
"""
@@ -394,8 +399,8 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
:return: A new :class:`LabelTensor` instance with the updated dtype
and/or device and retained labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
updated dtype and/or device and retained labels.
:rtype: LabelTensor
"""
@@ -405,11 +410,11 @@ class LabelTensor(torch.Tensor):
def clone(self, *args, **kwargs):
"""
Clone the :class:`LabelTensor`. For more details, see
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`.
:return: A new :class:`LabelTensor` instance with the same data and
labels but allocated in a different memory location.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
same data and labels but allocated in a different memory location.
:rtype: LabelTensor
"""
@@ -466,10 +471,11 @@ class LabelTensor(torch.Tensor):
"""
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list of LabelTensor label_tensors: The :class:`LabelTensor`
instances to stack. They need to have equal labels.
:return: A new :class:`LabelTensor` instance obtained by stacking the
input tensors vertically.
:param list of LabelTensor label_tensors: The
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
to have equal labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors vertically.
:rtype: LabelTensor
"""
@@ -524,15 +530,15 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index):
""" "
Override the __getitem__ method to handle the labels of the
:class:`LabelTensor` instance. It first performs __getitem__ operation
on the :class:`torch.Tensor` part of the instance, then updates the
labels based on the index.
:class:`pina.label_tensor.LabelTensor` instance. It first performs
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
then updates the labels based on the index.
:param index: The index used to access the item
:type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`LabelTensor` instance obtained __getitem__
operation on :class:`torch.Tensor` part of the instance, with the
updated labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
`__getitem__` operation on :class:`torch.Tensor` part of the
instance, with the updated labels.
:rtype: LabelTensor
:raises KeyError: If an invalid label index is provided.
@@ -665,7 +671,8 @@ class LabelTensor(torch.Tensor):
@staticmethod
def summation(tensors):
"""
Computes the summation of a list of :class:`LabelTensor` instances.
Computes the summation of a list of
:class:`pina.label_tensor.LabelTensor` instances.
:param list[LabelTensor] tensors: A list of tensors to sum. All
@@ -712,8 +719,8 @@ class LabelTensor(torch.Tensor):
For more details, see :meth:`torch.Tensor.reshape`.
:param tuple of int shape: The new shape of the tensor.
:return: A new :class:`LabelTensor` instance with the updated shape and
labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
updated shape and labels.
:rtype: LabelTensor
"""