Tmp fixes

This commit is contained in:
FilippoOlivo
2025-03-14 10:20:56 +01:00
parent e780671fd0
commit a3081cc09f
8 changed files with 74 additions and 69 deletions

View File

@@ -37,10 +37,11 @@ class Collector:
@property @property
def full(self): def full(self):
""" """
Returns True if the collector is full. The collector is considered full Returns ``True`` if the collector is full. The collector is considered
if all conditions have entries in the data_collection dictionary. full if all conditions have entries in the ``data_collection``
dictionary.
:return: True if all conditions are ready, False otherwise. :return: ``True`` if all conditions are ready, ``False`` otherwise.
:rtype: bool :rtype: bool
""" """
@@ -49,9 +50,9 @@ class Collector:
@full.setter @full.setter
def full(self, value): def full(self, value):
""" """
Set the full variable. Set the ``_full`` variable.
:param bool value: The value to set the full property to. :param bool value: The value to set the ``_full`` variable.
""" """
check_consistency(value, bool) check_consistency(value, bool)
@@ -117,7 +118,8 @@ class Collector:
""" """
Store inside data collections the sampled data of the problem. These Store inside data collections the sampled data of the problem. These
comes from the conditions that require sampling (e.g. comes from the conditions that require sampling (e.g.
DomainEquationCondition). :class:`~pina.condition.domain_equation_condition.
DomainEquationCondition`).
""" """
for condition_name in self.problem.conditions: for condition_name in self.problem.conditions:

View File

@@ -32,7 +32,7 @@ class Condition:
The class ``Condition`` is used to represent the constraints (physical The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the problem at hand. Condition objects are used to formulate the
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object. PINA :class:`~pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in four ways: Conditions can be specified in four ways:
1. By specifying the input and target of the condition; in such a 1. By specifying the input and target of the condition; in such a
@@ -61,7 +61,7 @@ class Condition:
input, there are different implementations of the condition. For more input, there are different implementations of the condition. For more
details, see :class:`~pina.condition.data_condition.DataCondition`. details, see :class:`~pina.condition.data_condition.DataCondition`.
Example:: :Example:
>>> from pina import Condition >>> from pina import Condition
>>> condition = Condition( >>> condition = Condition(

View File

@@ -28,7 +28,7 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None): def __new__(cls, input, conditional_variables=None):
""" """
Instantiate the appropriate subclass of :class:`DataCondition` based on Instantiate the appropriate subclass of :class:`DataCondition` based on
the type of `input`. the type of ``input``.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | :type input: torch.Tensor | LabelTensor | Graph |
@@ -72,7 +72,7 @@ class DataCondition(ConditionInterface):
:type conditional_variables: torch.Tensor or LabelTensor :type conditional_variables: torch.Tensor or LabelTensor
.. note:: .. note::
If either `input` is composed by a list of If either ``input`` is composed by a list of
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`, :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`,
all elements must have the same structure (keys and data all elements must have the same structure (keys and data
types) types)

View File

@@ -29,7 +29,7 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation): def __new__(cls, input, equation):
""" """
Instantiate the appropriate subclass of :class:`InputEquationCondition` Instantiate the appropriate subclass of :class:`InputEquationCondition`
based on the type of `input`. based on the type of ``input``.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph] :type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
@@ -74,7 +74,7 @@ class InputEquationCondition(ConditionInterface):
equation function. equation function.
.. note:: .. note::
If `input` is composed by a list of :class:`~pina.graph.Graph` If ``input`` is composed by a list of :class:`~pina.graph.Graph`
objects, all elements must have the same structure (keys and data objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a types). Moreover, at least one attribute must be a
:class:`~pina.label_tensor.LabelTensor`. :class:`~pina.label_tensor.LabelTensor`.

View File

@@ -52,7 +52,7 @@ class InputTargetCondition(ConditionInterface):
GraphInputTensorTargetCondition | GraphInputTensorTargetCondition |
pina.condition.input_target_condition.GraphInputGraphTargetCondition pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If `input` and/or `target` are not of type :raises ValueError: If ``input`` and/or ``target`` are not of type
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
""" """
@@ -94,7 +94,7 @@ class InputTargetCondition(ConditionInterface):
def __init__(self, input, target): def __init__(self, input, target):
""" """
Initialize the object by storing the `input` and `target` data. Initialize the object by storing the ``input`` and ``target`` data.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
@@ -104,7 +104,7 @@ class InputTargetCondition(ConditionInterface):
list[Data] | tuple[Graph] | tuple[Data] list[Data] | tuple[Graph] | tuple[Data]
.. note:: .. note::
If either `input` or `target` are composed by a list of If either ``input`` or ``target`` are composed by a list of
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements must have the same structure (keys and data objects, all elements must have the same structure (keys and data
types) types)
@@ -130,14 +130,14 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition): class TensorInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` `input` and `target` data. :class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data.
""" """
class TensorInputGraphTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor` or InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` `input` and :class:`~pina.label_tensor.LabelTensor` ``input`` and
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target` :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
data. data.
""" """
@@ -146,13 +146,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
class GraphInputTensorTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`~pina.graph.Graph` o InputTargetCondition subclass for :class:`~pina.graph.Graph` o
:class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or :class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` `target` data. :class:`~pina.label_tensor.LabelTensor` ``target`` data.
""" """
class GraphInputGraphTargetCondition(InputTargetCondition): class GraphInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`~pina.graph.Graph`/ InputTargetCondition subclass for :class:`~pina.graph.Graph`/
:class:`~torch_geometric.data.Data` `input` and `target` data. :class:`~torch_geometric.data.Data` ``input`` and ``target`` data.
""" """

View File

@@ -17,7 +17,7 @@ from ..collector import Collector
class DummyDataloader: class DummyDataloader:
""" """
Dataloader used when batch size is `None`. It returns the entire dataset Dataloader used when batch size is ``None``. It returns the entire dataset
in a single batch. in a single batch.
""" """
@@ -38,7 +38,7 @@ class DummyDataloader:
:param dataset: The dataset object to be processed. :param dataset: The dataset object to be processed.
:type dataset: PinaDataset :type dataset: PinaDataset
.. note:: This data loader is used when the batch size is `None`. .. note:: This data loader is used when the batch size is ``None``.
""" """
if ( if (
@@ -273,7 +273,7 @@ class PinaDataModule(LightningDataModule):
be in the range [0, 1]. be in the range [0, 1].
:param float val_size: Fraction of elements in the validation split. It :param float val_size: Fraction of elements in the validation split. It
must be in the range [0, 1]. must be in the range [0, 1].
:param batch_size: The batch size used for training. If `None`, the :param batch_size: The batch size used for training. If ``None``, the
entire dataset is returned in a single batch. entire dataset is returned in a single batch.
:type batch_size: int | None :type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting. :param bool shuffle: Whether to shuffle the dataset before splitting.

View File

@@ -33,8 +33,8 @@ class PinaDatasetFactory:
:param dict conditions_dict: Dictionary containing all the conditions :param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance. to be included in the dataset instance.
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`. :return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` | :rtype: pina.data.dataset.PinaTensorDataset |
:class:`~pina.data.dataset.PinaGraphDataset` pina.data.dataset.PinaGraphDataset
:raises ValueError: If an empty dictionary is provided. :raises ValueError: If an empty dictionary is provided.
""" """
@@ -255,7 +255,7 @@ class PinaGraphDataset(PinaDataset):
def _create_tensor_batch(self, data): def _create_tensor_batch(self, data):
""" """
Reshape properly `data` tensor to be processed handle by the graph Reshape properly ``data`` tensor to be processed handle by the graph
based models. based models.
:param data: torch.Tensor object of shape (N, ...) where N is the :param data: torch.Tensor object of shape (N, ...) where N is the

View File

@@ -29,7 +29,6 @@ class Graph(Data):
:return: A new instance of the :class:`~pina.graph.Graph` class. :return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph :rtype: Graph
""" """
# create class instance # create class instance
instance = Data.__new__(cls) instance = Data.__new__(cls)
@@ -54,23 +53,20 @@ class Graph(Data):
the graph undirected if required. For more details, see the the graph undirected if required. For more details, see the
:meth:`torch_geometric.data.Data` :meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features `(N, F)` where `F` is the :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
number of features per node. number of features per node.
:type x: torch.Tensor, LabelTensor :type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing :param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
the indices of the graph's edges. the indices of the graph's edges.
:param pos: A tensor of shape `(N, D)` representing the positions of `N` :param pos: A tensor of shape ``(N, D)`` representing the positions of
points in `D`-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'` :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
is the number of edge features ``F'`` is the number of edge features
:param bool undirected: Whether to make the graph undirected :param bool undirected: Whether to make the graph undirected
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
`torch_geometric.data.Data` class constructor. If the argument :class:`~torch_geometric.data.Data` class constructor.
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
object.
""" """
# preprocessing # preprocessing
self._preprocess_edge_index(edge_index, undirected) self._preprocess_edge_index(edge_index, undirected)
@@ -86,7 +82,6 @@ class Graph(Data):
:param kwargs: Attributes to be checked for consistency. :param kwargs: Attributes to be checked for consistency.
:type kwargs: dict :type kwargs: dict
""" """
# default types, specified in cls.__new__, by default they are Nont # default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override # if specified in **kwargs they get override
x, pos, edge_index, edge_attr = None, None, None, None x, pos, edge_index, edge_attr = None, None, None, None
@@ -110,12 +105,9 @@ class Graph(Data):
def _check_pos_consistency(pos): def _check_pos_consistency(pos):
""" """
Check if the position tensor is consistent. Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor. :param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent. :raises ValueError: If the position tensor is not consistent.
""" """
if pos is not None: if pos is not None:
check_consistency(pos, (torch.Tensor, LabelTensor)) check_consistency(pos, (torch.Tensor, LabelTensor))
if pos.ndim != 2: if pos.ndim != 2:
@@ -127,10 +119,8 @@ class Graph(Data):
Check if the edge index is consistent. Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor. :param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent. :raises ValueError: If the edge index tensor is not consistent.
""" """
check_consistency(edge_index, (torch.Tensor, LabelTensor)) check_consistency(edge_index, (torch.Tensor, LabelTensor))
if edge_index.ndim != 2: if edge_index.ndim != 2:
raise ValueError("edge_index must be a 2D tensor.") raise ValueError("edge_index must be a 2D tensor.")
@@ -145,11 +135,8 @@ class Graph(Data):
:param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor. :param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent. :raises ValueError: If the edge attribute tensor is not consistent.
""" """
if edge_attr is not None: if edge_attr is not None:
check_consistency(edge_attr, (torch.Tensor, LabelTensor)) check_consistency(edge_attr, (torch.Tensor, LabelTensor))
if edge_attr.ndim != 2: if edge_attr.ndim != 2:
@@ -171,7 +158,6 @@ class Graph(Data):
:param torch.Tensor x: The input tensor. :param torch.Tensor x: The input tensor.
:param torch.Tensor pos: The position tensor. :param torch.Tensor pos: The position tensor.
""" """
if x is not None: if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor)) check_consistency(x, (torch.Tensor, LabelTensor))
if x.ndim != 2: if x.ndim != 2:
@@ -193,7 +179,6 @@ class Graph(Data):
:return: The preprocessed edge index. :return: The preprocessed edge index.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
if undirected: if undirected:
edge_index = to_undirected(edge_index) edge_index = to_undirected(edge_index)
return edge_index return edge_index
@@ -246,7 +231,7 @@ class GraphBuilder:
:type edge_attr: torch.Tensor, optional :type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes. :param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`. If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional :type custom_edge_func: Callable, optional
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor. :class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the :return: A :class:`~pina.graph.Graph` instance constructed using the
@@ -266,6 +251,18 @@ class GraphBuilder:
@staticmethod @staticmethod
def _create_edge_attr(pos, edge_index, edge_attr, func): def _create_edge_attr(pos, edge_index, edge_attr, func):
"""
Create the edge attributes based on the input parameters.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:param bool edge_attr: Whether to compute the edge attributes.
:param Callable func: Function to compute the edge attributes.
:raises ValueError: If ``func`` is not a function.
:return: The edge attributes.
:rtype: torch.Tensor | LabelTensor | None
"""
check_consistency(edge_attr, bool) check_consistency(edge_attr, bool)
if edge_attr: if edge_attr:
if is_function(func): if is_function(func):
@@ -275,6 +272,14 @@ class GraphBuilder:
@staticmethod @staticmethod
def _build_edge_attr(pos, edge_index): def _build_edge_attr(pos, edge_index):
"""
Default function to compute the edge attributes.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:return: The edge attributes.
:rtype: torch.Tensor
"""
return ( return (
(pos[edge_index[0]] - pos[edge_index[1]]) (pos[edge_index[0]] - pos[edge_index[1]])
.abs() .abs()
@@ -293,37 +298,34 @@ class RadiusGraph(GraphBuilder):
edge_index based on a radius. Each point is connected to all the points edge_index based on a radius. Each point is connected to all the points
within the radius. within the radius.
:param pos: A tensor of shape `(N, D)` representing the positions of `N` :param pos: A tensor of shape ``(N, D)`` representing the positions of
points in `D`-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor or LabelTensor :type pos: torch.Tensor | LabelTensor
:param radius: The radius within which points are connected. :param float radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the :param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph` :class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
constructors. constructors.
:return: A :class:`~pina.graph.Graph` instance containing the input :return: A :class:`~pina.graph.Graph` instance containing the input
information and the computed edge_index. information and the computed ``edge_index``.
:rtype: Graph :rtype: Graph
""" """
edge_index = cls.compute_radius_graph(pos, radius) edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod @staticmethod
def compute_radius_graph(points, radius): def compute_radius_graph(points, radius):
""" """
Computes `edge_index` for a given set of points base on the radius. Computes ``edge_index`` for a given set of points base on the radius.
Each point is connected to all the points within the radius. Each point is connected to all the points within the radius.
:param points: A tensor of shape `(N, D)` representing the positions of :param points: A tensor of shape ``(N, D)`` representing the positions
N points in D-dimensional space. of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor :type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each :param float radius: The number of nearest neighbors to find for each
point. point.
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number :rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
of edges, representing the edge indices of the KNN graph. number of edges, representing the edge indices of the KNN graph.
""" """
dist = torch.cdist(points, points, p=2) dist = torch.cdist(points, points, p=2)
return ( return (
torch.nonzero(dist <= radius, as_tuple=False) torch.nonzero(dist <= radius, as_tuple=False)
@@ -342,8 +344,8 @@ class KNNGraph(GraphBuilder):
Extends the :class:`~pina.graph.GraphBuilder` class to compute Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm. edge_index based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N :param pos: A tensor of shape ``(N, D)`` representing the positions of
points in D-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when :param int neighbours: The number of nearest neighbors to consider when
building the graph. building the graph.
@@ -352,7 +354,7 @@ class KNNGraph(GraphBuilder):
and Graph classes and Graph classes
:return: A :class:`~pina.graph.Graph` instance containg the :return: A :class:`~pina.graph.Graph` instance containg the
information passed in input and the computed edge_index information passed in input and the computed ``edge_index``
:rtype: Graph :rtype: Graph
""" """
@@ -364,11 +366,11 @@ class KNNGraph(GraphBuilder):
""" """
Computes the edge_index based k-nearest neighbors graph algorithm Computes the edge_index based k-nearest neighbors graph algorithm
:param points: A tensor of shape (N, D) representing the positions of :param points: A tensor of shape ``(N, D)`` representing the positions
N points in D-dimensional space. of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor :type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point. :param int k: The number of nearest neighbors to find for each point.
:return: A tensor of shape (2, E), where E is the number of :return: A tensor of shape ``(2, E)``, where ``E`` is the number of
edges, representing the edge indices of the KNN graph. edges, representing the edge indices of the KNN graph.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
@@ -391,7 +393,8 @@ class LabelBatch(Batch):
Create a Batch object from a list of :class:`~torch_geometric.data.Data` Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects. objects.
:param data_list: List of Data/Graph objects :param data_list: List of :class:`~torch_geometric.data.Data` or
:class:`~pina.graph.Graph` objects.
:type data_list: list[Data] | list[Graph] :type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list :return: A Batch object containing the data in the list
:rtype: Batch :rtype: Batch