Fix doc data

This commit is contained in:
FilippoOlivo
2025-03-12 11:43:21 +01:00
committed by Nicola Demo
parent 6f97799284
commit 59e6ee595c
2 changed files with 48 additions and 49 deletions

View File

@@ -17,7 +17,7 @@ from ..collector import Collector
class DummyDataloader: class DummyDataloader:
""" """
Dataloader used when batch size is ``None``. It returns the entire dataset Dataloader used when batch size is `None`. It returns the entire dataset
in a single batch. in a single batch.
""" """
@@ -38,7 +38,7 @@ class DummyDataloader:
:param dataset: The dataset object to be processed. :param dataset: The dataset object to be processed.
:type dataset: PinaDataset :type dataset: PinaDataset
.. note:: This data loader is used when the batch size is ``None``. .. note:: This data loader is used when the batch size is `None`.
""" """
if ( if (
@@ -157,7 +157,7 @@ class Collator:
def _collate_tensor_dataset(data_list): def _collate_tensor_dataset(data_list):
""" """
Function used to collate the data when the dataset is a Function used to collate the data when the dataset is a
:class:`PinaTensorDataset`. :class:`pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated. :param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor] :type data_list: list[torch.Tensor] | list[LabelTensor]
@@ -165,7 +165,7 @@ class Collator:
:rtype: dict :rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`LabelTensor`. :class:`pina.label_tensor.LabelTensor`.
""" """
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):
@@ -177,15 +177,15 @@ class Collator:
def _collate_graph_dataset(self, data_list): def _collate_graph_dataset(self, data_list):
""" """
Function used to collate the data when the dataset is a Function used to collate the data when the dataset is a
:class:`PinaGraphDataset`. :class:`pina.data.dataset.PinaGraphDataset`.
:param data_list: Elememts to be collated. :param data_list: Elememts to be collated.
:type data_list: list[torch_geometric.data.Data] | list[Graph] :type data_list: list[Data] | list[Graph]
:return: Batch of data. :return: Batch of data.
:rtype: dict :rtype: dict
:raises RuntimeError: If the data is not a :raises RuntimeError: If the data is not a
:class:`torch_geometric.data.Data` or a :class:`Graph`. :class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
""" """
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):
@@ -248,7 +248,7 @@ class PinaSampler:
class PinaDataModule(LightningDataModule): class PinaDataModule(LightningDataModule):
""" """
This class extends :class:`pytorch_lightning.LightningDataModule`, This class extends :class:`lightning.pytorch.LightningDataModule`,
allowing proper creation and management of different types of datasets allowing proper creation and management of different types of datasets
defined in PINA. defined in PINA.
""" """
@@ -536,8 +536,7 @@ class PinaDataModule(LightningDataModule):
""" """
Define the maximum length of the conditions. Define the maximum length of the conditions.
:param split: The splits of the dataset. :param dict split: The splits of the dataset.
:type split: dict
:return: The maximum length of the conditions. :return: The maximum length of the conditions.
:rtype: dict :rtype: dict
""" """
@@ -559,7 +558,7 @@ class PinaDataModule(LightningDataModule):
Create the validation dataloader. Create the validation dataloader.
:return: The validation dataloader :return: The validation dataloader
:rtype: DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("val", self.val_dataset) return self._create_dataloader("val", self.val_dataset)
@@ -568,7 +567,7 @@ class PinaDataModule(LightningDataModule):
Create the training dataloader Create the training dataloader
:return: The training dataloader :return: The training dataloader
:rtype: DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("train", self.train_dataset) return self._create_dataloader("train", self.train_dataset)
@@ -577,7 +576,7 @@ class PinaDataModule(LightningDataModule):
Create the testing dataloader Create the testing dataloader
:return: The testing dataloader :return: The testing dataloader
:rtype: DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("test", self.test_dataset) return self._create_dataloader("test", self.test_dataset)

View File

@@ -15,24 +15,26 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset Depending on the type inside the conditions, it creates a different dataset
object: object:
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and - :class:`pina.data.dataset.PinaTensorDataset` for handling
:class:`LabelTensor` data. :class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data` - :class:`pina.data.dataset.PinaGraphDataset` for handling
data. :class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
""" """
def __new__(cls, conditions_dict, **kwargs): def __new__(cls, conditions_dict, **kwargs):
""" """
Instantiate the appropriate subclass of :class:`PinaDataset`. Instantiate the appropriate subclass of
:class:`pina.data.dataset.PinaDataset`.
If a graph is present in the conditions, returns a If a graph is present in the conditions, returns a
:class:`PinaGraphDataset`, otherwise returns a :class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`. :class:`pina.data.dataset.PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing all the conditions :param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance. to be included in the dataset instance.
:return: A subclass of :class:`PinaDataset`. :return: A subclass of :class:`pina.data.dataset.PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset` :rtype: :class:`pina.data.dataset.PinaTensorDataset` |
:class:`pina.data.dataset.PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided. :raises ValueError: If an empty dictionary is provided.
""" """
@@ -73,25 +75,25 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC): class PinaDataset(Dataset, ABC):
""" """
Abstract class for the PINA dataset. It defines the common interface for Abstract class for the PINA dataset. It defines the common interface for
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes. the :class:`pina.data.dataset.PinaTensorDataset` and
:class:`pina.data.dataset.PinaGraphDataset` classes.
""" """
def __init__( def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching self, conditions_dict, max_conditions_lengths, automatic_batching
): ):
""" """
Initialize a :class:`PinaDataset` instance by storing the provided Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
conditions dictionary, the maximum number of conditions to consider, the providedconditions dictionary, the maximum number of conditions to
and the automatic batching flag. consider, and the automatic batching flag.
:param conditions_dict: Dictionary containing the conditions. :param dict conditions_dict: Dictionary containing the conditions with
:type conditions_dict: dict data.
:param max_conditions_lengths: Specifies the maximum number of data :param dict max_conditions_lengths: Specifies the maximum number of data
points to include in a single batch for each condition. points to include in a single batch for each condition.
:type max_conditions_lengths: dict :param bool automatic_batching: Indicates whether PyTorch automatic
:param automatic_batching: Indicates whether PyTorch automatic batching batching is enabled in
is enabled in :class:`PinaDataModule`. :class:`pina.data.data_module.PinaDataModule`.
:type automatic_batching: bool
""" """
# Store the conditions dictionary # Store the conditions dictionary
@@ -134,8 +136,7 @@ class PinaDataset(Dataset, ABC):
Return the index itself. This is used when automatic batching is Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader. disabled to postpone the data retrieval to the dataloader.
:param idx: Index. :param int idx: Index.
:type idx: int
:return: Index. :return: Index.
:rtype: int :rtype: int
""" """
@@ -174,8 +175,7 @@ class PinaDataset(Dataset, ABC):
""" """
Return data from the dataset given a list of indices. Return data from the dataset given a list of indices.
:param idx: List of indices. :param list[int] idx: List of indices.
:type idx: list[int]
:return: A dictionary containing the data at the given indices. :return: A dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
""" """
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
class PinaTensorDataset(PinaDataset): class PinaTensorDataset(PinaDataset):
""" """
Dataset class for the PINA dataset with :class:`torch.Tensor` and Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`LabelTensor` data. :class:`pina.label_tensor.LabelTensor` data.
""" """
# Override _retrive_data method for torch.Tensor data # Override _retrive_data method for torch.Tensor data
@@ -213,9 +213,9 @@ class PinaTensorDataset(PinaDataset):
""" """
Retrieve data from the dataset given a list of indices. Retrieve data from the dataset given a list of indices.
:param data: Dictionary containing the data :param dict data: Dictionary containing the data
(only torch.Tensor/LableTensor). (only :class:`torch.Tensor` or
:type data: dict :class:`pina.label_tensor.LabelTensor`).
:param list[int] idx_list: indices to retrieve. :param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices. :return: Dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
@@ -236,17 +236,17 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset): class PinaGraphDataset(PinaDataset):
""" """
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data` Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
and :class:`Graph` data. and :class:`pina.graph.Graph` data.
""" """
def _create_graph_batch(self, data): def _create_graph_batch(self, data):
""" """
Create a LabelBatch object from a list of Create a LabelBatch object from a list of
:class:`torch_geometric.data.Data` objects. :class:`~torch_geometric.data.Data` objects.
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list[torch_geometric.data.Data] | list[Graph] :type data: list[Data] | list[Graph]
:return: LabelBatch object all the graph collated in a single batch :return: LabelBatch object all the graph collated in a single batch
disconnected graphs. disconnected graphs.
:rtype: LabelBatch :rtype: LabelBatch
@@ -256,13 +256,13 @@ class PinaGraphDataset(PinaDataset):
def _create_tensor_batch(self, data): def _create_tensor_batch(self, data):
""" """
Reshape properly ``data`` tensor to be processed handle by the graph Reshape properly `data` tensor to be processed handle by the graph
based models. based models.
:param data: torch.Tensor object of shape (N, ...) where N is the :param data: torch.Tensor object of shape (N, ...) where N is the
number of data points. number of data points.
:type data: torch.Tensor | LabelTensor :type data: torch.Tensor | LabelTensor
:return: reshaped torch.Tensor or LabelTensor object. :return: Reshaped tensor object.
:rtype: torch.Tensor | LabelTensor :rtype: torch.Tensor | LabelTensor
""" """
out = data.reshape(-1, *data.shape[2:]) out = data.reshape(-1, *data.shape[2:])
@@ -270,11 +270,11 @@ class PinaGraphDataset(PinaDataset):
def create_batch(self, data): def create_batch(self, data):
""" """
Create a Batch object from a list of :class:`torch_geometric.data.Data` Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects. objects.
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list[torch_geometric.data.Data] | list[Graph] :type data: list[Data] | list[Graph]
:return: Batch object. :return: Batch object.
:rtype: Batch | PinaBatch :rtype: Batch | PinaBatch
""" """