Fix rendering and codacy

This commit is contained in:
FilippoOlivo
2025-03-14 15:05:16 +01:00
committed by Nicola Demo
parent 05105dd517
commit 001d1fc9cf
8 changed files with 98 additions and 96 deletions

View File

@@ -118,7 +118,7 @@ class Collector:
""" """
Store inside data collections the sampled data of the problem. These Store inside data collections the sampled data of the problem. These
comes from the conditions that require sampling (e.g. comes from the conditions that require sampling (e.g.
:class:`~pina.condition.domain_equation_condition. :class:`~pina.condition.domain_equation_condition.\
DomainEquationCondition`). DomainEquationCondition`).
""" """

View File

@@ -244,7 +244,7 @@ class PinaSampler:
class PinaDataModule(LightningDataModule): class PinaDataModule(LightningDataModule):
""" """
This class extends :class:`lightning.pytorch.LightningDataModule`, This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
allowing proper creation and management of different types of datasets allowing proper creation and management of different types of datasets
defined in PINA. defined in PINA.
""" """
@@ -274,18 +274,18 @@ class PinaDataModule(LightningDataModule):
:param float val_size: Fraction of elements in the validation split. It :param float val_size: Fraction of elements in the validation split. It
must be in the range [0, 1]. must be in the range [0, 1].
:param batch_size: The batch size used for training. If ``None``, the :param batch_size: The batch size used for training. If ``None``, the
entire dataset is returned in a single batch. entire dataset is returned in a single batch. Default is ``None``.
:type batch_size: int | None :type batch_size: int
:param bool shuffle: Whether to shuffle the dataset before splitting. :param bool shuffle: Whether to shuffle the dataset before splitting.
Default True. Default ``Tru``e.
:param bool repeat: Whether to repeat the dataset indefinitely. :param bool repeat: Whether to repeat the dataset indefinitely.
Default False. Default ``False``.
:param automatic_batching: Whether to enable automatic batching. :param automatic_batching: Whether to enable automatic batching.
Default False. Default ``False``.
:param int num_workers: Number of worker threads for data loading. :param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading). Default ``0`` (serial loading).
:param bool pin_memory: Whether to use pinned memory for faster data :param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. Default False. transfer to GPU. Default ``False``.
:raises ValueError: If at least one of the splits is negative. :raises ValueError: If at least one of the splits is negative.
:raises ValueError: If the sum of the splits is different from 1. :raises ValueError: If the sum of the splits is different from 1.
@@ -643,7 +643,7 @@ class PinaDataModule(LightningDataModule):
Return all the input points coming from all the datasets. Return all the input points coming from all the datasets.
:return: The input points for training. :return: The input points for training.
:rtype dict :rtype: dict
""" """
to_return = {} to_return = {}

View File

@@ -12,12 +12,13 @@ class PinaDatasetFactory:
""" """
Factory class for the PINA dataset. Factory class for the PINA dataset.
Depending on the type inside the conditions, it creates a different dataset Depending on the data type inside the conditions, it instanciate an object
object: belonging to the appropriate subclass of
:class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
- :class:`~pina.data.dataset.PinaTensorDataset` for handling - :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data. :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
- :class:`~pina.data.dataset.PinaGraphDataset` for handling - :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
""" """
@@ -33,8 +34,7 @@ class PinaDatasetFactory:
:param dict conditions_dict: Dictionary containing all the conditions :param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance. to be included in the dataset instance.
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`. :return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: pina.data.dataset.PinaTensorDataset | :rtype: PinaTensorDataset | PinaGraphDataset
pina.data.dataset.PinaGraphDataset
:raises ValueError: If an empty dictionary is provided. :raises ValueError: If an empty dictionary is provided.
""" """
@@ -74,8 +74,9 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC): class PinaDataset(Dataset, ABC):
""" """
Abstract class for the PINA dataset. It defines the common interface for Abstract class for the PINA dataset which extends the PyTorch
:class:`~pina.data.dataset.PinaTensorDataset` and :class:`~torch.utils.data.Dataset` class. It defines the common interface
for :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`~pina.data.dataset.PinaGraphDataset` classes. :class:`~pina.data.dataset.PinaGraphDataset` classes.
""" """
@@ -83,13 +84,15 @@ class PinaDataset(Dataset, ABC):
self, conditions_dict, max_conditions_lengths, automatic_batching self, conditions_dict, max_conditions_lengths, automatic_batching
): ):
""" """
Initialize :class:`~pina.data.dataset.PinaDataset` instance by storing Initialize the instance by storing the conditions dictionary, the
the provided conditions dictionary, and the automatic batching flag. maximum number of items per conditions to consider, and the automatic
batching flag.
:param dict conditions_dict: Dictionary containing the conditions with :param dict conditions_dict: A dictionary mapping condition names to
data. their respective data. Each key represents a condition name, and the
:param dict max_conditions_lengths: Specifies the maximum number of data corresponding value is a dictionary containing the associated data.
points to include in a single batch for each condition. :param dict max_conditions_lengths: Maximum number of data points that
can be included in a single batch per condition.
:param bool automatic_batching: Indicates whether PyTorch automatic :param bool automatic_batching: Indicates whether PyTorch automatic
batching is enabled in batching is enabled in
:class:`~pina.data.data_module.PinaDataModule`. :class:`~pina.data.data_module.PinaDataModule`.
@@ -258,8 +261,8 @@ class PinaGraphDataset(PinaDataset):
Reshape properly ``data`` tensor to be processed handle by the graph Reshape properly ``data`` tensor to be processed handle by the graph
based models. based models.
:param data: torch.Tensor object of shape (N, ...) where N is the :param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
number of data points. the number of data objects.
:type data: torch.Tensor | LabelTensor :type data: torch.Tensor | LabelTensor
:return: Reshaped tensor object. :return: Reshaped tensor object.
:rtype: torch.Tensor | LabelTensor :rtype: torch.Tensor | LabelTensor
@@ -275,7 +278,7 @@ class PinaGraphDataset(PinaDataset):
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list[Data] | list[Graph] :type data: list[Data] | list[Graph]
:return: Batch object. :return: Batch object.
:rtype: Batch | PinaBatch :rtype: Batch | LabelBatch
""" """
if isinstance(data[0], Data): if isinstance(data[0], Data):

View File

@@ -23,9 +23,8 @@ class Graph(Data):
Create a new instance of the :class:`~pina.graph.Graph` class by Create a new instance of the :class:`~pina.graph.Graph` class by
checking the consistency of the input data and storing the attributes. checking the consistency of the input data and storing the attributes.
:param kwargs: Parameters used to initialize the :param dict kwargs: Parameters used to initialize the
:class:`~pina.graph.Graph` object. :class:`~pina.graph.Graph` object.
:type kwargs: dict
:return: A new instance of the :class:`~pina.graph.Graph` class. :return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph :rtype: Graph
""" """
@@ -56,8 +55,8 @@ class Graph(Data):
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
number of features per node. number of features per node.
:type x: torch.Tensor, LabelTensor :type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing :param torch.Tensor edge_index: A tensor of shape ``(2, E)``
the indices of the graph's edges. representing the indices of the graph's edges.
:param pos: A tensor of shape ``(N, D)`` representing the positions of :param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
@@ -80,8 +79,7 @@ class Graph(Data):
""" """
Check the consistency of the types of the input data. Check the consistency of the types of the input data.
:param kwargs: Attributes to be checked for consistency. :param dict kwargs: Attributes to be checked for consistency.
:type kwargs: dict
""" """
# default types, specified in cls.__new__, by default they are Nont # default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override # if specified in **kwargs they get override
@@ -134,7 +132,8 @@ class Graph(Data):
Check if the edge attribute tensor is consistent in type and shape Check if the edge attribute tensor is consistent in type and shape
with the edge index. with the edge index.
:param torch.Tensor edge_attr: The edge attribute tensor. :param edge_attr: The edge attribute tensor.
:type edge_attr: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index tensor. :param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent. :raises ValueError: If the edge attribute tensor is not consistent.
""" """
@@ -156,8 +155,11 @@ class Graph(Data):
Check if the input tensor x is consistent with the position tensor Check if the input tensor x is consistent with the position tensor
`pos`. `pos`.
:param torch.Tensor x: The input tensor. :param x: The input tensor.
:param torch.Tensor pos: The position tensor. :type x: torch.Tensor | LabelTensor
:param pos: The position tensor.
:type pos: torch.Tensor | LabelTensor
:raises ValueError: If the input tensor is not consistent.
""" """
if x is not None: if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor)) check_consistency(x, (torch.Tensor, LabelTensor))
@@ -166,9 +168,6 @@ class Graph(Data):
if pos is not None: if pos is not None:
if x.size(0) != pos.size(0): if x.size(0) != pos.size(0):
raise ValueError("Inconsistent number of nodes.") raise ValueError("Inconsistent number of nodes.")
if pos is not None:
if x.size(0) != pos.size(0):
raise ValueError("Inconsistent number of nodes.")
@staticmethod @staticmethod
def _preprocess_edge_index(edge_index, undirected): def _preprocess_edge_index(edge_index, undirected):
@@ -292,7 +291,7 @@ class GraphBuilder:
class RadiusGraph(GraphBuilder): class RadiusGraph(GraphBuilder):
""" """
Extends the :class:`~pina.graph.GraphBuilder` class to compute Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a radius. Each point is connected to all the points ``edge_index`` based on a radius. Each point is connected to all the points
within the radius. within the radius.
""" """
@@ -305,11 +304,10 @@ class RadiusGraph(GraphBuilder):
``N`` points in ``D``-dimensional space. ``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected. :param float radius: The radius within which points are connected.
:param kwargs: Additional keyword arguments to be passed to the :param dict kwargs: The additional keyword arguments to be passed to
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph` :class:`GraphBuilder` and :class:`Graph` classes.
constructors. :return: A :class:`~pina.graph.Graph` instance with the computed
:return: A :class:`~pina.graph.Graph` instance containing the input ``edge_index``.
information and the computed ``edge_index``.
:rtype: Graph :rtype: Graph
""" """
edge_index = cls.compute_radius_graph(pos, radius) edge_index = cls.compute_radius_graph(pos, radius)
@@ -318,16 +316,16 @@ class RadiusGraph(GraphBuilder):
@staticmethod @staticmethod
def compute_radius_graph(points, radius): def compute_radius_graph(points, radius):
""" """
Computes ``edge_index`` for a given set of points base on the radius. Computes the ``edge_index`` based on the radius. Each point is connected
Each point is connected to all the points within the radius. to all the points within the radius.
:param points: A tensor of shape ``(N, D)`` representing the positions :param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space. of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor :type points: torch.Tensor | LabelTensor
:param float radius: The number of nearest neighbors to find for each :param float radius: The radius within which points are connected.
point. :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the representing the edge indices of the graph.
number of edges, representing the edge indices of the KNN graph. :rtype: torch.Tensor
""" """
dist = torch.cdist(points, points, p=2) dist = torch.cdist(points, points, p=2)
return ( return (
@@ -340,7 +338,7 @@ class RadiusGraph(GraphBuilder):
class KNNGraph(GraphBuilder): class KNNGraph(GraphBuilder):
""" """
Extends the :class:`~pina.graph.GraphBuilder` class to compute Extends the :class:`~pina.graph.GraphBuilder` class to compute
edge_index based on a K-nearest neighbors algorithm. ``edge_index`` based on a K-nearest neighbors algorithm.
""" """
def __new__(cls, pos, neighbours, **kwargs): def __new__(cls, pos, neighbours, **kwargs):
@@ -353,12 +351,11 @@ class KNNGraph(GraphBuilder):
:type pos: torch.Tensor | LabelTensor :type pos: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when :param int neighbours: The number of nearest neighbors to consider when
building the graph. building the graph.
:Keyword Arguments: :param dict kwargs: The additional keyword arguments to be passed to
The additional keyword arguments to be passed to GraphBuilder :class:`GraphBuilder` and :class:`Graph` classes.
and Graph classes
:return: A :class:`~pina.graph.Graph` instance containg the :return: A :class:`~pina.graph.Graph` instance with the computed
information passed in input and the computed ``edge_index`` ``edge_index``.
:rtype: Graph :rtype: Graph
""" """
@@ -366,41 +363,45 @@ class KNNGraph(GraphBuilder):
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
@staticmethod @staticmethod
def compute_knn_graph(points, k): def compute_knn_graph(points, neighbours):
""" """
Computes the edge_index based k-nearest neighbors graph algorithm Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
:param points: A tensor of shape ``(N, D)`` representing the positions :param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space. of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor :type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point. :param int neighbours: The number of nearest neighbors to consider when
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of building the graph.
edges, representing the edge indices of the KNN graph. :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
representing the edge indices of the graph.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
dist = torch.cdist(points, points, p=2) dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:] knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
row = torch.arange(points.size(0)).repeat_interleave(k) :, 1:
]
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
col = knn_indices.flatten() col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
class LabelBatch(Batch): class LabelBatch(Batch):
""" """
Add extract function to torch_geometric Batch object Extends the :class:`~torch_geometric.data.Batch` class to include
:class:`~pina.label_tensor.LabelTensor` objects.
""" """
@classmethod @classmethod
def from_data_list(cls, data_list): def from_data_list(cls, data_list):
""" """
Create a Batch object from a list of :class:`~torch_geometric.data.Data` Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects. or :class:`~pina.graph.Graph` objects.
:param data_list: List of :class:`~torch_geometric.data.Data` or :param data_list: List of :class:`~torch_geometric.data.Data` or
:class:`~pina.graph.Graph` objects. :class:`~pina.graph.Graph` objects.
:type data_list: list[Data] | list[Graph] :type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list :return: A :class:`Batch` object containing the input data.
:rtype: Batch :rtype: Batch
""" """
# Store the labels of Data/Graph objects (all data have the same labels) # Store the labels of Data/Graph objects (all data have the same labels)

View File

@@ -216,10 +216,10 @@ class LabelTensor(torch.Tensor):
def extract(self, labels_to_extract): def extract(self, labels_to_extract):
""" """
Extract the subset of the original tensor by returning all the positions Extract the subset of the original tensor by returning all the positions
corresponding to the passed ``label_to_extract``. If ``label_to_extract`` corresponding to the passed ``label_to_extract``. If
is a dictionary, the keys are the dimension names and the values are the ``label_to_extract`` is a dictionary, the keys are the dimension names
labels to extract. If a single label or a list of labels is passed, the and the values are the labels to extract. If a single label or a list
last dimension is considered. of labels is passed, the last dimension is considered.
:Example: :Example:
>>> from pina import LabelTensor >>> from pina import LabelTensor

View File

@@ -109,9 +109,7 @@ class FourierIntegralKernel(torch.nn.Module):
if all(isinstance(i, list) for i in n_modes) and len(layers) != len( if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
n_modes n_modes
): ):
raise RuntimeError( raise RuntimeError("Inconsistent number of layers and modes.")
"Inconsistent number of layers and modes."
)
if all(isinstance(i, int) for i in n_modes): if all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers) n_modes = [n_modes] * len(layers)
else: else: