Tmp fixes
This commit is contained in:
committed by
Nicola Demo
parent
c164db874b
commit
10ccae3a33
@@ -37,10 +37,11 @@ class Collector:
|
||||
@property
|
||||
def full(self):
|
||||
"""
|
||||
Returns True if the collector is full. The collector is considered full
|
||||
if all conditions have entries in the data_collection dictionary.
|
||||
Returns ``True`` if the collector is full. The collector is considered
|
||||
full if all conditions have entries in the ``data_collection``
|
||||
dictionary.
|
||||
|
||||
:return: True if all conditions are ready, False otherwise.
|
||||
:return: ``True`` if all conditions are ready, ``False`` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
@@ -49,9 +50,9 @@ class Collector:
|
||||
@full.setter
|
||||
def full(self, value):
|
||||
"""
|
||||
Set the full variable.
|
||||
Set the ``_full`` variable.
|
||||
|
||||
:param bool value: The value to set the full property to.
|
||||
:param bool value: The value to set the ``_full`` variable.
|
||||
"""
|
||||
|
||||
check_consistency(value, bool)
|
||||
@@ -117,7 +118,8 @@ class Collector:
|
||||
"""
|
||||
Store inside data collections the sampled data of the problem. These
|
||||
comes from the conditions that require sampling (e.g.
|
||||
DomainEquationCondition).
|
||||
:class:`~pina.condition.domain_equation_condition.
|
||||
DomainEquationCondition`).
|
||||
"""
|
||||
|
||||
for condition_name in self.problem.conditions:
|
||||
|
||||
@@ -32,7 +32,7 @@ class Condition:
|
||||
The class ``Condition`` is used to represent the constraints (physical
|
||||
equations, boundary conditions, etc.) that should be satisfied in the
|
||||
problem at hand. Condition objects are used to formulate the
|
||||
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
|
||||
PINA :class:`~pina.problem.abstract_problem.AbstractProblem` object.
|
||||
Conditions can be specified in four ways:
|
||||
|
||||
1. By specifying the input and target of the condition; in such a
|
||||
@@ -61,7 +61,7 @@ class Condition:
|
||||
input, there are different implementations of the condition. For more
|
||||
details, see :class:`~pina.condition.data_condition.DataCondition`.
|
||||
|
||||
Example::
|
||||
:Example:
|
||||
|
||||
>>> from pina import Condition
|
||||
>>> condition = Condition(
|
||||
|
||||
@@ -28,7 +28,7 @@ class DataCondition(ConditionInterface):
|
||||
def __new__(cls, input, conditional_variables=None):
|
||||
"""
|
||||
Instantiate the appropriate subclass of :class:`DataCondition` based on
|
||||
the type of `input`.
|
||||
the type of ``input``.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph |
|
||||
@@ -72,7 +72,7 @@ class DataCondition(ConditionInterface):
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
|
||||
.. note::
|
||||
If either `input` is composed by a list of
|
||||
If either ``input`` is composed by a list of
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`,
|
||||
all elements must have the same structure (keys and data
|
||||
types)
|
||||
|
||||
@@ -29,7 +29,7 @@ class InputEquationCondition(ConditionInterface):
|
||||
def __new__(cls, input, equation):
|
||||
"""
|
||||
Instantiate the appropriate subclass of :class:`InputEquationCondition`
|
||||
based on the type of `input`.
|
||||
based on the type of ``input``.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
|
||||
@@ -74,7 +74,7 @@ class InputEquationCondition(ConditionInterface):
|
||||
equation function.
|
||||
|
||||
.. note::
|
||||
If `input` is composed by a list of :class:`~pina.graph.Graph`
|
||||
If ``input`` is composed by a list of :class:`~pina.graph.Graph`
|
||||
objects, all elements must have the same structure (keys and data
|
||||
types). Moreover, at least one attribute must be a
|
||||
:class:`~pina.label_tensor.LabelTensor`.
|
||||
|
||||
@@ -52,7 +52,7 @@ class InputTargetCondition(ConditionInterface):
|
||||
GraphInputTensorTargetCondition |
|
||||
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||
|
||||
:raises ValueError: If `input` and/or `target` are not of type
|
||||
:raises ValueError: If ``input`` and/or ``target`` are not of type
|
||||
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
||||
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||
"""
|
||||
@@ -94,7 +94,7 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
def __init__(self, input, target):
|
||||
"""
|
||||
Initialize the object by storing the `input` and `target` data.
|
||||
Initialize the object by storing the ``input`` and ``target`` data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||
@@ -104,7 +104,7 @@ class InputTargetCondition(ConditionInterface):
|
||||
list[Data] | tuple[Graph] | tuple[Data]
|
||||
|
||||
.. note::
|
||||
If either `input` or `target` are composed by a list of
|
||||
If either ``input`` or ``target`` are composed by a list of
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
||||
objects, all elements must have the same structure (keys and data
|
||||
types)
|
||||
@@ -130,14 +130,14 @@ class InputTargetCondition(ConditionInterface):
|
||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` `input` and `target` data.
|
||||
:class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data.
|
||||
"""
|
||||
|
||||
|
||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` `input` and
|
||||
:class:`~pina.label_tensor.LabelTensor` ``input`` and
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
|
||||
data.
|
||||
"""
|
||||
@@ -146,13 +146,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
||||
:class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` `target` data.
|
||||
:class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` ``target`` data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
||||
:class:`~torch_geometric.data.Data` `input` and `target` data.
|
||||
:class:`~torch_geometric.data.Data` ``input`` and ``target`` data.
|
||||
"""
|
||||
|
||||
@@ -17,7 +17,7 @@ from ..collector import Collector
|
||||
|
||||
class DummyDataloader:
|
||||
"""
|
||||
Dataloader used when batch size is `None`. It returns the entire dataset
|
||||
Dataloader used when batch size is ``None``. It returns the entire dataset
|
||||
in a single batch.
|
||||
"""
|
||||
|
||||
@@ -38,7 +38,7 @@ class DummyDataloader:
|
||||
:param dataset: The dataset object to be processed.
|
||||
:type dataset: PinaDataset
|
||||
|
||||
.. note:: This data loader is used when the batch size is `None`.
|
||||
.. note:: This data loader is used when the batch size is ``None``.
|
||||
"""
|
||||
|
||||
if (
|
||||
@@ -273,7 +273,7 @@ class PinaDataModule(LightningDataModule):
|
||||
be in the range [0, 1].
|
||||
:param float val_size: Fraction of elements in the validation split. It
|
||||
must be in the range [0, 1].
|
||||
:param batch_size: The batch size used for training. If `None`, the
|
||||
:param batch_size: The batch size used for training. If ``None``, the
|
||||
entire dataset is returned in a single batch.
|
||||
:type batch_size: int | None
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
|
||||
@@ -33,8 +33,8 @@ class PinaDatasetFactory:
|
||||
:param dict conditions_dict: Dictionary containing all the conditions
|
||||
to be included in the dataset instance.
|
||||
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`
|
||||
:rtype: pina.data.dataset.PinaTensorDataset |
|
||||
pina.data.dataset.PinaGraphDataset
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
@@ -255,7 +255,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
|
||||
def _create_tensor_batch(self, data):
|
||||
"""
|
||||
Reshape properly `data` tensor to be processed handle by the graph
|
||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
||||
based models.
|
||||
|
||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||
|
||||
@@ -29,7 +29,6 @@ class Graph(Data):
|
||||
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
# create class instance
|
||||
instance = Data.__new__(cls)
|
||||
|
||||
@@ -54,23 +53,20 @@ class Graph(Data):
|
||||
the graph undirected if required. For more details, see the
|
||||
:meth:`torch_geometric.data.Data`
|
||||
|
||||
:param x: Optional tensor of node features `(N, F)` where `F` is the
|
||||
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
|
||||
number of features per node.
|
||||
:type x: torch.Tensor, LabelTensor
|
||||
:param torch.Tensor edge_index: A tensor of shape `(2, E)` representing
|
||||
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
|
||||
the indices of the graph's edges.
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param edge_attr: Optional tensor of edge_featured `(E, F')` where `F'`
|
||||
is the number of edge features
|
||||
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
|
||||
``F'`` is the number of edge features
|
||||
:param bool undirected: Whether to make the graph undirected
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
`torch_geometric.data.Data` class constructor. If the argument
|
||||
is a `torch.Tensor` or `LabelTensor`, it is included in the graph
|
||||
object.
|
||||
:class:`~torch_geometric.data.Data` class constructor.
|
||||
"""
|
||||
|
||||
# preprocessing
|
||||
self._preprocess_edge_index(edge_index, undirected)
|
||||
|
||||
@@ -86,7 +82,6 @@ class Graph(Data):
|
||||
:param kwargs: Attributes to be checked for consistency.
|
||||
:type kwargs: dict
|
||||
"""
|
||||
|
||||
# default types, specified in cls.__new__, by default they are Nont
|
||||
# if specified in **kwargs they get override
|
||||
x, pos, edge_index, edge_attr = None, None, None, None
|
||||
@@ -110,12 +105,9 @@ class Graph(Data):
|
||||
def _check_pos_consistency(pos):
|
||||
"""
|
||||
Check if the position tensor is consistent.
|
||||
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
|
||||
:raises ValueError: If the position tensor is not consistent.
|
||||
"""
|
||||
|
||||
if pos is not None:
|
||||
check_consistency(pos, (torch.Tensor, LabelTensor))
|
||||
if pos.ndim != 2:
|
||||
@@ -127,10 +119,8 @@ class Graph(Data):
|
||||
Check if the edge index is consistent.
|
||||
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
|
||||
:raises ValueError: If the edge index tensor is not consistent.
|
||||
"""
|
||||
|
||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||
if edge_index.ndim != 2:
|
||||
raise ValueError("edge_index must be a 2D tensor.")
|
||||
@@ -145,11 +135,8 @@ class Graph(Data):
|
||||
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
|
||||
:raises ValueError: If the edge attribute tensor is not consistent.
|
||||
|
||||
"""
|
||||
|
||||
if edge_attr is not None:
|
||||
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
|
||||
if edge_attr.ndim != 2:
|
||||
@@ -171,7 +158,6 @@ class Graph(Data):
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
"""
|
||||
|
||||
if x is not None:
|
||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||
if x.ndim != 2:
|
||||
@@ -193,7 +179,6 @@ class Graph(Data):
|
||||
:return: The preprocessed edge index.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
if undirected:
|
||||
edge_index = to_undirected(edge_index)
|
||||
return edge_index
|
||||
@@ -246,7 +231,7 @@ class GraphBuilder:
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides `edge_attr`.
|
||||
:type custom_edge_func: callable, optional
|
||||
:type custom_edge_func: Callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:class:`~pina.graph.Graph` class constructor.
|
||||
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||
@@ -266,6 +251,18 @@ class GraphBuilder:
|
||||
|
||||
@staticmethod
|
||||
def _create_edge_attr(pos, edge_index, edge_attr, func):
|
||||
"""
|
||||
Create the edge attributes based on the input parameters.
|
||||
|
||||
:param pos: Positions of the points.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: Edge indices.
|
||||
:param bool edge_attr: Whether to compute the edge attributes.
|
||||
:param Callable func: Function to compute the edge attributes.
|
||||
:raises ValueError: If ``func`` is not a function.
|
||||
:return: The edge attributes.
|
||||
:rtype: torch.Tensor | LabelTensor | None
|
||||
"""
|
||||
check_consistency(edge_attr, bool)
|
||||
if edge_attr:
|
||||
if is_function(func):
|
||||
@@ -275,6 +272,14 @@ class GraphBuilder:
|
||||
|
||||
@staticmethod
|
||||
def _build_edge_attr(pos, edge_index):
|
||||
"""
|
||||
Default function to compute the edge attributes.
|
||||
:param pos: Positions of the points.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param torch.Tensor edge_index: Edge indices.
|
||||
:return: The edge attributes.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return (
|
||||
(pos[edge_index[0]] - pos[edge_index[1]])
|
||||
.abs()
|
||||
@@ -293,37 +298,34 @@ class RadiusGraph(GraphBuilder):
|
||||
edge_index based on a radius. Each point is connected to all the points
|
||||
within the radius.
|
||||
|
||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||
points in `D`-dimensional space.
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param radius: The radius within which points are connected.
|
||||
:type radius: float
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param float radius: The radius within which points are connected.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
|
||||
constructors.
|
||||
:return: A :class:`~pina.graph.Graph` instance containing the input
|
||||
information and the computed edge_index.
|
||||
information and the computed ``edge_index``.
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
edge_index = cls.compute_radius_graph(pos, radius)
|
||||
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def compute_radius_graph(points, radius):
|
||||
"""
|
||||
Computes `edge_index` for a given set of points base on the radius.
|
||||
Computes ``edge_index`` for a given set of points base on the radius.
|
||||
Each point is connected to all the points within the radius.
|
||||
|
||||
:param points: A tensor of shape `(N, D)` representing the positions of
|
||||
N points in D-dimensional space.
|
||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||
of ``N`` points in ``D``-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param float radius: The number of nearest neighbors to find for each
|
||||
point.
|
||||
:rtype torch.Tensor: A tensor of shape `(2, E)`, where `E` is the number
|
||||
of edges, representing the edge indices of the KNN graph.
|
||||
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
|
||||
number of edges, representing the edge indices of the KNN graph.
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
return (
|
||||
torch.nonzero(dist <= radius, as_tuple=False)
|
||||
@@ -342,8 +344,8 @@ class KNNGraph(GraphBuilder):
|
||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||
edge_index based on a K-nearest neighbors algorithm.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||
``N`` points in ``D``-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param int neighbours: The number of nearest neighbors to consider when
|
||||
building the graph.
|
||||
@@ -352,7 +354,7 @@ class KNNGraph(GraphBuilder):
|
||||
and Graph classes
|
||||
|
||||
:return: A :class:`~pina.graph.Graph` instance containg the
|
||||
information passed in input and the computed edge_index
|
||||
information passed in input and the computed ``edge_index``
|
||||
:rtype: Graph
|
||||
"""
|
||||
|
||||
@@ -364,11 +366,11 @@ class KNNGraph(GraphBuilder):
|
||||
"""
|
||||
Computes the edge_index based k-nearest neighbors graph algorithm
|
||||
|
||||
:param points: A tensor of shape (N, D) representing the positions of
|
||||
N points in D-dimensional space.
|
||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||
of ``N`` points in ``D``-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param int k: The number of nearest neighbors to find for each point.
|
||||
:return: A tensor of shape (2, E), where E is the number of
|
||||
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
|
||||
edges, representing the edge indices of the KNN graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
@@ -391,7 +393,8 @@ class LabelBatch(Batch):
|
||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||
objects.
|
||||
|
||||
:param data_list: List of Data/Graph objects
|
||||
:param data_list: List of :class:`~torch_geometric.data.Data` or
|
||||
:class:`~pina.graph.Graph` objects.
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
:return: A Batch object containing the data in the list
|
||||
:rtype: Batch
|
||||
|
||||
Reference in New Issue
Block a user