Other fixes

This commit is contained in:
FilippoOlivo
2025-03-12 12:29:18 +01:00
committed by Nicola Demo
parent ae796ce34c
commit f587c3bf65
9 changed files with 85 additions and 82 deletions

View File

@@ -18,7 +18,8 @@ class Collector:
and the problem and initializing the data collections (dictionary where
data will be stored).
:param AbstractProblem problem: The problem to collect data from.
:param pina.problem.abstract_problem.AbstractProblem problem: The
problem to collect data from.
"""
# creating a hook between collector and problem
self.problem = problem
@@ -71,16 +72,17 @@ class Collector:
Problem connected to the collector.
:return: The problem from which the data is collected.
:rtype: AbstractProblem
:rtype: pina.problem.abstract_problem.AbstractProblem
"""
return self._problem
@problem.setter
def problem(self, value):
"""
Ser the problem connected to the collector.
Set the problem connected to the collector.
:param AbstractProblem value: The problem to connect to the collector.
:param pina.problem.abstract_problem.AbstractProblem value: The problem
to connect to the collector.
"""
self._problem = value

View File

@@ -89,15 +89,15 @@ class Condition:
Create a new condition object based on the keyword arguments passed.
- `input` and `target`:
:class:`pina.condition.input_target_condition.InputTargetCondition`
:class:`~pina.condition.input_target_condition.InputTargetCondition`
- `domain` and `equation`:
:class:`pina.condition.domain_equation_condition.
:class:`~pina.condition.domain_equation_condition.
DomainEquationCondition`
- `input` and `equation`: :class:`pina.condition.
- `input` and `equation`: :class:`~pina.condition.
input_equation_condition.InputEquationCondition`
- `input`: :class:`pina.condition.data_condition.DataCondition`
- `input`: :class:`~pina.condition.data_condition.DataCondition`
- `input` and `conditional_variables`:
:class:`pina.condition.data_condition.DataCondition`
:class:`~pina.condition.data_condition.DataCondition`
:raises ValueError: No valid condition has been found.
:return: A new condition instance belonging to the proper class.

View File

@@ -35,7 +35,7 @@ class DataCondition(ConditionInterface):
pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
or :class:`~torch_geometric.data.Data`.
@@ -68,7 +68,7 @@ class DataCondition(ConditionInterface):
:type conditional_variables: torch.Tensor or LabelTensor
.. note::
If either `input` is composed by a list of :class:`pina.graph.Graph`
If either `input` is composed by a list of :class:`~pina.graph.Graph`
or :class:`~torch_geometric.data.Data` objects, all elements must
have the same structure (keys and data types)
"""
@@ -80,12 +80,12 @@ class DataCondition(ConditionInterface):
class TensorDataCondition(DataCondition):
"""
DataCondition for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input data
:class:`~pina.label_tensor.LabelTensor` input data
"""
class GraphDataCondition(DataCondition):
"""
DataCondition for :class:`pina.graph.Graph` or
DataCondition for :class:`~pina.graph.Graph` or
:class:`~torch_geometric.data.Data` input data
"""

View File

@@ -35,7 +35,7 @@ class InputEquationCondition(ConditionInterface):
pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`.
"""
# If the class is already a subclass, return the instance
@@ -62,15 +62,16 @@ class InputEquationCondition(ConditionInterface):
Initialize the InputEquationCondition by storing the input and equation.
:param input: Input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
:type input: pina.label_tensor.LabelTensor | pina.graph.Graph |
list[pina.graph.Graph] | tuple[pina.graph.Graph]
:param EquationInterface equation: Equation object containing the
equation function.
.. note::
If `input` is composed by a list of :class:`pina.graph.Graph`
If `input` is composed by a list of :class:`~pina.graph.Graph`
objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a
:class:`pina.label_tensor.LabelTensor`.
:class:`~pina.label_tensor.LabelTensor`.
"""
super().__init__()
@@ -90,21 +91,21 @@ class InputEquationCondition(ConditionInterface):
class InputTensorEquationCondition(InputEquationCondition):
"""
InputEquationCondition subclass for :class:`pina.label_tensor.LabelTensor`
InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor`
input data.
"""
class InputGraphEquationCondition(InputEquationCondition):
"""
InputEquationCondition subclass for :class:`pina.graph.Graph` input data.
InputEquationCondition subclass for :class:`~pina.graph.Graph` input data.
"""
@staticmethod
def _check_label_tensor(input):
"""
Check if at least one :class:`pina.label_tensor.LabelTensor` is present
in the :class:`pina.graph.Graph` object.
Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
in the :class:`~pina.graph.Graph` object.
:param input: Input data.
:type input: torch.Tensor | Graph | Data

View File

@@ -44,8 +44,8 @@ class InputTargetCondition(ConditionInterface):
pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
"""
if cls != InputTargetCondition:
return super().__new__(cls)
@@ -97,7 +97,7 @@ class InputTargetCondition(ConditionInterface):
.. note::
If either `input` or `target` are composed by a list of
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data`
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements must have the same structure (keys and data
types)
"""
@@ -122,29 +122,29 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input and target data.
:class:`~pina.label_tensor.LabelTensor` input and target data.
"""
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` input and
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
:class:`~pina.label_tensor.LabelTensor` input and
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
data.
"""
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`pina.graph.Graph` o
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor` target data.
:class:`~pina.label_tensor.LabelTensor` target data.
"""
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`pina.graph.Graph`/
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
:class:`~torch_geometric.data.Data` input and target data.
"""

View File

@@ -157,7 +157,7 @@ class Collator:
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
:class:`pina.data.dataset.PinaTensorDataset`.
:class:`~pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor]
@@ -165,7 +165,7 @@ class Collator:
:rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`pina.label_tensor.LabelTensor`.
:class:`~pina.label_tensor.LabelTensor`.
"""
if isinstance(data_list[0], LabelTensor):
@@ -177,7 +177,7 @@ class Collator:
def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
:class:`pina.data.dataset.PinaGraphDataset`.
:class:`~pina.data.dataset.PinaGraphDataset`.
:param data_list: Elememts to be collated.
:type data_list: list[Data] | list[Graph]
@@ -185,7 +185,7 @@ class Collator:
:rtype: dict
:raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
"""
if isinstance(data_list[0], LabelTensor):

View File

@@ -15,26 +15,26 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`pina.data.dataset.PinaTensorDataset` for handling
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
- :class:`pina.data.dataset.PinaGraphDataset` for handling
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Instantiate the appropriate subclass of
:class:`pina.data.dataset.PinaDataset`.
:class:`~pina.data.dataset.PinaDataset`.
If a graph is present in the conditions, returns a
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`pina.data.dataset.PinaTensorDataset`.
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance.
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
:class:`pina.data.dataset.PinaGraphDataset`
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
:class:`~pina.data.dataset.PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided.
"""
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC):
"""
Abstract class for the PINA dataset. It defines the common interface for
the :class:`pina.data.dataset.PinaTensorDataset` and
:class:`pina.data.dataset.PinaGraphDataset` classes.
the :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`~pina.data.dataset.PinaGraphDataset` classes.
"""
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
the providedconditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
points to include in a single batch for each condition.
:param bool automatic_batching: Indicates whether PyTorch automatic
batching is enabled in
:class:`pina.data.data_module.PinaDataModule`.
:class:`~pina.data.data_module.PinaDataModule`.
"""
# Store the conditions dictionary
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
class PinaTensorDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`pina.label_tensor.LabelTensor` data.
:class:`~pina.label_tensor.LabelTensor` data.
"""
# Override _retrive_data method for torch.Tensor data
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
:param dict data: Dictionary containing the data
(only :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor`).
:class:`~pina.label_tensor.LabelTensor`).
:param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
and :class:`pina.graph.Graph` data.
and :class:`~pina.graph.Graph` data.
"""
def _create_graph_batch(self, data):

View File

@@ -20,13 +20,13 @@ class Graph(Data):
**kwargs,
):
"""
Create a new instance of the :class:`pina.graph.Graph` class by checking
Create a new instance of the :class:`~pina.graph.Graph` class by checking
the consistency of the input data and storing the attributes.
:param kwargs: Parameters used to initialize the
:class:`pina.graph.Graph` object.
:class:`~pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`pina.graph.Graph` class.
:return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph
"""
@@ -230,7 +230,7 @@ class GraphBuilder:
):
"""
Compute the edge attributes and create a new instance of the
:class:`pina.graph.Graph` class.
:class:`~pina.graph.Graph` class.
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
points in `D`-dimensional space.
@@ -248,8 +248,8 @@ class GraphBuilder:
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the
:class:`pina.graph.Graph` class constructor.
:return: A :class:`pina.graph.Graph` instance constructed using the
:class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph
"""
@@ -289,7 +289,7 @@ class RadiusGraph(GraphBuilder):
def __new__(cls, pos, radius, **kwargs):
"""
Extends the :class:`pina.graph.GraphBuilder` class to compute
Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points
within the radius.
@@ -299,9 +299,9 @@ class RadiusGraph(GraphBuilder):
:param radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
constructors.
:return: A :class:`pina.graph.Graph` instance containing the input
:return: A :class:`~pina.graph.Graph` instance containing the input
information and the computed edge_index.
:rtype: Graph
"""
@@ -339,7 +339,7 @@ class KNNGraph(GraphBuilder):
def __new__(cls, pos, neighbours, **kwargs):
"""
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index
based on a K-nearest neighbors algorithm.
:param pos: A tensor of shape (N, D) representing the positions of N
@@ -351,7 +351,7 @@ class KNNGraph(GraphBuilder):
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:return: A :class:`pina.graph.Graph` instance containg the
:return: A :class:`~pina.graph.Graph` instance containg the
information passed in input and the computed edge_index
:rtype: Graph
"""

View File

@@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor):
@staticmethod
def __new__(cls, x, labels, *args, **kwargs):
"""
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
class.
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
:class:`pina.label_tensor.LabelTensor`.
:class:`~pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
:return: The instance of the :class:`~pina.label_tensor.LabelTensor`
class.
:rtype: LabelTensor
"""
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
@property
def tensor(self):
"""
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
object.
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor
"""
@@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels):
"""
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
the labels and a :class:`torch.Tensor`. Internally, the initialization
method will store check the compatibility of the labels with the tensor
shape.
@@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor):
def __str__(self):
"""
The string representation of the :class:`pina.label_tensor.LabelTensor`.
The string representation of the :class:`~pina.label_tensor.LabelTensor`.
:return: String representation of the
:class:`pina.label_tensor.LabelTensor` instance.
:class:`~pina.label_tensor.LabelTensor` instance.
:rtype: str
"""
@@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor):
Concatenate a list of tensors along a specified dimension. For more
details, see :meth:`torch.cat`.
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
:param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor`
instances to concatenate
:param int dim: dimensions on which you want to perform the operation
(default is 0)
@@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors, with the updated labels.
:rtype: LabelTensor
"""
@@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor):
:param bool mode: A boolean value indicating whether the tensor should
track gradients.If `True`, the tensor will track gradients;
if `False`, it will not.
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
:return: The :class:`~pina.label_tensor.LabelTensor` itself with the
updated `requires_grad` state and retained labels.
:rtype: LabelTensor
"""
@@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated dtype and/or device and retained labels.
:rtype: LabelTensor
"""
@@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor):
def clone(self, *args, **kwargs):
"""
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
:meth:`torch.Tensor.clone`.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
same data and labels but allocated in a different memory location.
:rtype: LabelTensor
"""
@@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor):
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
:param list of LabelTensor label_tensors: The
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
:class:`~pina.label_tensor.LabelTensor` instances to stack. They need
to have equal labels.
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors vertically.
:rtype: LabelTensor
"""
@@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor):
def __getitem__(self, index):
""" "
Override the __getitem__ method to handle the labels of the
:class:`pina.label_tensor.LabelTensor` instance. It first performs
:class:`~pina.label_tensor.LabelTensor` instance. It first performs
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
then updates the labels based on the index.
:param index: The index used to access the item
:type index: int | str | tuple of int | list ot int | torch.Tensor
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
`__getitem__` operation on :class:`torch.Tensor` part of the
instance, with the updated labels.
:rtype: LabelTensor
@@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor):
def summation(tensors):
"""
Computes the summation of a list of
:class:`pina.label_tensor.LabelTensor` instances.
:class:`~pina.label_tensor.LabelTensor` instances.
:param list[LabelTensor] tensors: A list of tensors to sum. All
@@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor):
For more details, see :meth:`torch.Tensor.reshape`.
:param tuple of int shape: The new shape of the tensor.
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
updated shape and labels.
:rtype: LabelTensor
"""