This commit is contained in:
FilippoOlivo
2025-03-12 12:03:47 +01:00
committed by Nicola Demo
parent 59e6ee595c
commit ae796ce34c
6 changed files with 128 additions and 97 deletions

View File

@@ -32,8 +32,8 @@ class ConditionInterface(metaclass=ABCMeta):
""" """
Set the problem to which the condition is associated. Set the problem to which the condition is associated.
:param value: Problem to which the condition is associated. :param pina.problem.AbstractProblem value: Problem to which the
:type value: pina.problem.AbstractProblem condition is associated.
""" """
self._problem = value self._problem = value

View File

@@ -12,7 +12,7 @@ from ..graph import Graph
class DataCondition(ConditionInterface): class DataCondition(ConditionInterface):
""" """
This condition must be used every time a Unsupervised Loss is needed in This condition must be used every time a Unsupervised Loss is needed in
the Solver. The conditionalvariable can be passed as extra-input when the Solver. The `conditional_variable` can be passed as extra-input when
the model learns a conditional distribution. the model learns a conditional distribution.
""" """
@@ -31,7 +31,8 @@ class DataCondition(ConditionInterface):
:param conditional_variables: Conditional variables for the condition. :param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor :type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition. :return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition :rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`, :raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`, :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,

View File

@@ -30,7 +30,9 @@ class InputEquationCondition(ConditionInterface):
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation function. equation function.
:return: Subclass of InputEquationCondition, based on the input type. :return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition | InputGraphEquationCondition :rtype: pina.condition.input_equation_condition.
InputTensorEquationCondition |
pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type :raises ValueError: If input is not of type
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`. :class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
@@ -105,7 +107,7 @@ class InputGraphEquationCondition(InputEquationCondition):
in the :class:`pina.graph.Graph` object. in the :class:`pina.graph.Graph` object.
:param input: Input data. :param input: Input data.
:type input: torch.Tensor | Graph | torch_geometric.data.Data :type input: torch.Tensor | Graph | Data
:raises ValueError: If the input data object does not contain at least :raises ValueError: If the input data object does not contain at least
one LabelTensor. one LabelTensor.

View File

@@ -25,23 +25,26 @@ class InputTargetCondition(ConditionInterface):
the types of input and target data. the types of input and target data.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | :type input: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] | torch_geometric.data.Data | list[Graph] |
list[torch_geometric.data.Data] | tuple[Graph] | list[torch_geometric.data.Data] | tuple[Graph] |
tuple[torch_geometric.data.Data] tuple[torch_geometric.data.Data]
:param target: Target data for the condition. :param target: Target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | :type target: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] | torch_geometric.data.Data | list[Graph] |
list[torch_geometric.data.Data] | tuple[Graph] | list[torch_geometric.data.Data] | tuple[Graph] |
tuple[torch_geometric.data.Data] tuple[torch_geometric.data.Data]
:return: Subclass of InputTargetCondition :return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition | \ :rtype: pina.condition.input_target_condition.
TensorInputGraphTargetCondition | \ TensorInputTensorTargetCondition |
GraphInputTensorTargetCondition | \ pina.condition.input_target_condition.
GraphInputGraphTargetCondition TensorInputGraphTargetCondition |
pina.condition.input_target_condition.
GraphInputTensorTargetCondition |
pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type :raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`, :class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. :class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
""" """
if cls != InputTargetCondition: if cls != InputTargetCondition:

View File

@@ -1,5 +1,5 @@
""" """
This module provides an interface to build torch_geometric.data.Data objects. Module to build Graph objects and perform operations on them.
""" """
import torch import torch
@@ -11,7 +11,8 @@ from .utils import check_consistency, is_function
class Graph(Data): class Graph(Data):
""" """
A class to build torch_geometric.data.Data objects. Extends :class:`~torch_geometric.data.Data` class to include additional
checks and functionlities.
""" """
def __new__( def __new__(
@@ -19,13 +20,16 @@ class Graph(Data):
**kwargs, **kwargs,
): ):
""" """
Instantiates a new instance of the Graph class, performing type Create a new instance of the :class:`pina.graph.Graph` class by checking
consistency checks. the consistency of the input data and storing the attributes.
:param kwargs: Parameters to construct the Graph object. :param kwargs: Parameters used to initialize the
:return: A new instance of the Graph class. :class:`pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`pina.graph.Graph` class.
:rtype: Graph :rtype: Graph
""" """
# create class instance # create class instance
instance = Data.__new__(cls) instance = Data.__new__(cls)
@@ -45,27 +49,28 @@ class Graph(Data):
**kwargs, **kwargs,
): ):
""" """
Initialize the Graph object by setting the node features, edge index, Initialize the object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the the graph undirected if required. For more details, see the
:meth: `torch_geometric.data.Data` :meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features (N, F) where F is the number :param x: Optional tensor of node features `(N, F)` where `F` is the
of features per node. number of features per node.
:type x: torch.Tensor, LabelTensor :type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape (2, E) representing :param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
the indices of the graph's edges. the indices of the graph's edges.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in D-dimensional space. points in `D`-dimensional space.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is :param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
the number of edge features is the number of edge features
:param bool undirected: Whether to make the graph undirected :param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
`torch_geometric.data.Data` class constructor. If the argument `torch_geometric.data.Data` class constructor. If the argument
is a `torch.Tensor` or `LabelTensor`, it is included in the Data is a `torch.Tensor` or `LabelTensor`, it is included in the graph
object as a graph parameter. object.
""" """
# preprocessing # preprocessing
self._preprocess_edge_index(edge_index, undirected) self._preprocess_edge_index(edge_index, undirected)
@@ -107,6 +112,8 @@ class Graph(Data):
Check if the position tensor is consistent. Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor. :param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent.
""" """
if pos is not None: if pos is not None:
@@ -120,6 +127,8 @@ class Graph(Data):
Check if the edge index is consistent. Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor. :param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent.
""" """
check_consistency(edge_index, (torch.Tensor, LabelTensor)) check_consistency(edge_index, (torch.Tensor, LabelTensor))
@@ -136,6 +145,9 @@ class Graph(Data):
:param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor. :param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
""" """
if edge_attr is not None: if edge_attr is not None:
@@ -217,27 +229,28 @@ class GraphBuilder:
**kwargs, **kwargs,
): ):
""" """
Compute the edge attributes and create a new instance of the Graph Compute the edge attributes and create a new instance of the
class. :class:`pina.graph.Graph` class.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in D-dimensional space. points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor :type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape (2, E) representing the indices of :param edge_index: A tensor of shape `(2, E)` representing the indices
the graph's edges. of the graph's edges.
:type edge_index: torch.Tensor :type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape (N, F), where F is :param x: Optional tensor of node features of shape `(N, F)`, where `F`
the number of features per node. is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional :type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape (E, F), :param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
where F is the number of features per edge. where `F` is the number of features per edge.
:type edge_attr: torch.Tensor, optional :type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes. :param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`. If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional :type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the Graph class :param kwargs: Additional keyword arguments passed to the
constructor. :class:`pina.graph.Graph` class constructor.
:return: A Graph instance constructed using the provided information. :return: A :class:`pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph :rtype: Graph
""" """
edge_attr = cls._create_edge_attr( edge_attr = cls._create_edge_attr(
@@ -271,42 +284,46 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder): class RadiusGraph(GraphBuilder):
""" """
A class to build a radius graph. A class to build a graph based on a radius.
""" """
def __new__(cls, pos, radius, **kwargs): def __new__(cls, pos, radius, **kwargs):
""" """
Extends the `GraphBuilder` class to compute edge_index based on a Extends the :class:`pina.graph.GraphBuilder` class to compute
radius. Each point is connected to all the points within the radius. edge_index based on a radius. Each point is connected to all the points
within the radius.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in D-dimensional space. points in `D`-dimensional space.
:type pos: torch.Tensor or LabelTensor :type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected. :param radius: The radius within which points are connected.
:type radius: float :type radius: float
:param kwargs: Additional keyword arguments to be passed to the :param kwargs: Additional keyword arguments to be passed to the
`GraphBuilder` and `Graph` constructors. :class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
:return: A `Graph` instance containing the input information and the constructors.
computed edge_index. :return: A :class:`pina.graph.Graph` instance containing the input
information and the computed edge_index.
:rtype: Graph :rtype: Graph
""" """
edge_index = cls.compute_radius_graph(pos, radius) edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod @staticmethod
def compute_radius_graph(points, radius): def compute_radius_graph(points, radius):
""" """
Computes edge_index for a given set of points base on the radius. Computes `edge_index` for a given set of points base on the radius.
Each point is connected to all the points within the radius. Each point is connected to all the points within the radius.
:param points: A tensor of shape (N, D) representing the positions of :param points: A tensor of shape `(N, D)` representing the positions of
N points in D-dimensional space. N points in D-dimensional space.
:type points: torch.Tensor | LabelTensor :type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each :param float radius: The number of nearest neighbors to find for each
point. point.
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of :rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
edges, representing the edge indices of the KNN graph. of edges, representing the edge indices of the KNN graph.
""" """
dist = torch.cdist(points, points, p=2) dist = torch.cdist(points, points, p=2)
return ( return (
torch.nonzero(dist <= radius, as_tuple=False) torch.nonzero(dist <= radius, as_tuple=False)
@@ -317,13 +334,13 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder): class KNNGraph(GraphBuilder):
""" """
A class to build a KNN graph. A class to build a K-nearest neighbors graph.
""" """
def __new__(cls, pos, neighbours, **kwargs): def __new__(cls, pos, neighbours, **kwargs):
""" """
Creates a new instance of the Graph class using k-nearest neighbors Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
algorithm to define the edges. based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space. points in D-dimensional space.
@@ -334,8 +351,8 @@ class KNNGraph(GraphBuilder):
The additional keyword arguments to be passed to GraphBuilder The additional keyword arguments to be passed to GraphBuilder
and Graph classes and Graph classes
:return: Graph instance containg the information passed in input and :return: A :class:`pina.graph.Graph` instance containg the
the computed edge_index information passed in input and the computed edge_index
:rtype: Graph :rtype: Graph
""" """
@@ -371,7 +388,8 @@ class LabelBatch(Batch):
@classmethod @classmethod
def from_data_list(cls, data_list): def from_data_list(cls, data_list):
""" """
Create a Batch object from a list of Data objects. Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
:param data_list: List of Data/Graph objects :param data_list: List of Data/Graph objects
:type data_list: list[Data] | list[Graph] :type data_list: list[Data] | list[Graph]

View File

@@ -14,13 +14,15 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def __new__(cls, x, labels, *args, **kwargs): def __new__(cls, x, labels, *args, **kwargs):
""" """
Create a new instance of the :class:`LabelTensor` class. Create a new instance of the :class:`pina.label_tensor.LabelTensor`
class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`LabelTensor`. :class:`pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor. :param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict :type labels: str | list[str] | dict
:return: The instance of the :class:`LabelTensor` class. :return: The instance of the :class:`pina.label_tensor.LabelTensor`
class.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -31,9 +33,10 @@ class LabelTensor(torch.Tensor):
@property @property
def tensor(self): def tensor(self):
""" """
Give the tensor part of the :class:`LabelTensor` object. Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
object.
:return: tensor part of the :class:`LabelTensor`. :return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
@@ -41,9 +44,10 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels): def __init__(self, x, labels):
""" """
Construct a :class:`LabelTensor` by passing a dict of the labels and a Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
:class:`torch.Tensor`. Internally, the initialization method will store the labels and a :class:`torch.Tensor`. Internally, the initialization
check the compatibility of the labels with the tensor shape. method will store check the compatibility of the labels with the tensor
shape.
:Example: :Example:
>>> from pina import LabelTensor >>> from pina import LabelTensor
@@ -271,9 +275,10 @@ class LabelTensor(torch.Tensor):
def __str__(self): def __str__(self):
""" """
The string representation of the :class:`LabelTensor`. The string representation of the :class:`pina.label_tensor.LabelTensor`.
:return: String representation of the :class:`LabelTensor` instance. :return: String representation of the
:class:`pina.label_tensor.LabelTensor` instance.
:rtype: str :rtype: str
""" """
@@ -290,8 +295,8 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`. details, see :meth:`torch.cat`.
:param list[LabelTensor] tensors: :class:`LabelTensor` instances to :param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
concatenate instances to concatenate
:param int dim: dimensions on which you want to perform the operation :param int dim: dimensions on which you want to perform the operation
(default is 0) (default is 0)
:return: A new :class:`LabelTensor' instance obtained by concatenating :return: A new :class:`LabelTensor' instance obtained by concatenating
@@ -346,8 +351,8 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack. :param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape. All tensors must have the same shape.
:return: A new :class:`LabelTensor` instance obtained by stacking the :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
input tensors, with the updated labels. by stacking the input tensors, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -368,8 +373,8 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should :param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients; track gradients.If `True`, the tensor will track gradients;
if `False`, it will not. if `False`, it will not.
:return: The :class:`LabelTensor` itself with the updated :return: The :class:`pina.label_tensor.LabelTensor` itself with the
`requires_grad` state and retained labels. updated `requires_grad` state and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -394,8 +399,8 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`. :meth:`torch.Tensor.to`.
:return: A new :class:`LabelTensor` instance with the updated dtype :return: A new :class:`pina.label_tensor.LabelTensor` instance with the
and/or device and retained labels. updated dtype and/or device and retained labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -405,11 +410,11 @@ class LabelTensor(torch.Tensor):
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs):
""" """
Clone the :class:`LabelTensor`. For more details, see Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`. :meth:`torch.Tensor.clone`.
:return: A new :class:`LabelTensor` instance with the same data and :return: A new :class:`pina.label_tensor.LabelTensor` instance with the
labels but allocated in a different memory location. same data and labels but allocated in a different memory location.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -466,10 +471,11 @@ class LabelTensor(torch.Tensor):
""" """
Stack tensors vertically. For more details, see :meth:`torch.vstack`. Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list of LabelTensor label_tensors: The :class:`LabelTensor` :param list of LabelTensor label_tensors: The
instances to stack. They need to have equal labels. :class:`pina.label_tensor.LabelTensor` instances to stack. They need
:return: A new :class:`LabelTensor` instance obtained by stacking the to have equal labels.
input tensors vertically. :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors vertically.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -524,15 +530,15 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index): def __getitem__(self, index):
""" " """ "
Override the __getitem__ method to handle the labels of the Override the __getitem__ method to handle the labels of the
:class:`LabelTensor` instance. It first performs __getitem__ operation :class:`pina.label_tensor.LabelTensor` instance. It first performs
on the :class:`torch.Tensor` part of the instance, then updates the __getitem__ operation on the :class:`torch.Tensor` part of the instance,
labels based on the index. then updates the labels based on the index.
:param index: The index used to access the item :param index: The index used to access the item
:type index: int | str | tuple of int | list ot int | torch.Tensor :type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`LabelTensor` instance obtained __getitem__ :return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
operation on :class:`torch.Tensor` part of the instance, with the `__getitem__` operation on :class:`torch.Tensor` part of the
updated labels. instance, with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
:raises KeyError: If an invalid label index is provided. :raises KeyError: If an invalid label index is provided.
@@ -665,7 +671,8 @@ class LabelTensor(torch.Tensor):
@staticmethod @staticmethod
def summation(tensors): def summation(tensors):
""" """
Computes the summation of a list of :class:`LabelTensor` instances. Computes the summation of a list of
:class:`pina.label_tensor.LabelTensor` instances.
:param list[LabelTensor] tensors: A list of tensors to sum. All :param list[LabelTensor] tensors: A list of tensors to sum. All
@@ -712,8 +719,8 @@ class LabelTensor(torch.Tensor):
For more details, see :meth:`torch.Tensor.reshape`. For more details, see :meth:`torch.Tensor.reshape`.
:param tuple of int shape: The new shape of the tensor. :param tuple of int shape: The new shape of the tensor.
:return: A new :class:`LabelTensor` instance with the updated shape and :return: A new :class:`pina.label_tensor.LabelTensor` instance with the
labels. updated shape and labels.
:rtype: LabelTensor :rtype: LabelTensor
""" """