From 3db9b3b9401726e1df78c5c4f0c4ed789a41ddcc Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 12 Mar 2025 11:43:21 +0100 Subject: [PATCH] Fix doc data --- pina/data/data_module.py | 25 +++++++------- pina/data/dataset.py | 72 ++++++++++++++++++++-------------------- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 48e3b2f..56a8f50 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -17,7 +17,7 @@ from ..collector import Collector class DummyDataloader: """ - Dataloader used when batch size is ``None``. It returns the entire dataset + Dataloader used when batch size is `None`. It returns the entire dataset in a single batch. """ @@ -38,7 +38,7 @@ class DummyDataloader: :param dataset: The dataset object to be processed. :type dataset: PinaDataset - .. note:: This data loader is used when the batch size is ``None``. + .. note:: This data loader is used when the batch size is `None`. """ if ( @@ -157,7 +157,7 @@ class Collator: def _collate_tensor_dataset(data_list): """ Function used to collate the data when the dataset is a - :class:`PinaTensorDataset`. + :class:`pina.data.dataset.PinaTensorDataset`. :param data_list: Elements to be collated. :type data_list: list[torch.Tensor] | list[LabelTensor] @@ -165,7 +165,7 @@ class Collator: :rtype: dict :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`LabelTensor`. + :class:`pina.label_tensor.LabelTensor`. """ if isinstance(data_list[0], LabelTensor): @@ -177,15 +177,15 @@ class Collator: def _collate_graph_dataset(self, data_list): """ Function used to collate the data when the dataset is a - :class:`PinaGraphDataset`. + :class:`pina.data.dataset.PinaGraphDataset`. :param data_list: Elememts to be collated. - :type data_list: list[torch_geometric.data.Data] | list[Graph] + :type data_list: list[Data] | list[Graph] :return: Batch of data. :rtype: dict :raises RuntimeError: If the data is not a - :class:`torch_geometric.data.Data` or a :class:`Graph`. + :class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`. """ if isinstance(data_list[0], LabelTensor): @@ -248,7 +248,7 @@ class PinaSampler: class PinaDataModule(LightningDataModule): """ - This class extends :class:`pytorch_lightning.LightningDataModule`, + This class extends :class:`lightning.pytorch.LightningDataModule`, allowing proper creation and management of different types of datasets defined in PINA. """ @@ -536,8 +536,7 @@ class PinaDataModule(LightningDataModule): """ Define the maximum length of the conditions. - :param split: The splits of the dataset. - :type split: dict + :param dict split: The splits of the dataset. :return: The maximum length of the conditions. :rtype: dict """ @@ -559,7 +558,7 @@ class PinaDataModule(LightningDataModule): Create the validation dataloader. :return: The validation dataloader - :rtype: DataLoader + :rtype: torch.utils.data.DataLoader """ return self._create_dataloader("val", self.val_dataset) @@ -568,7 +567,7 @@ class PinaDataModule(LightningDataModule): Create the training dataloader :return: The training dataloader - :rtype: DataLoader + :rtype: torch.utils.data.DataLoader """ return self._create_dataloader("train", self.train_dataset) @@ -577,7 +576,7 @@ class PinaDataModule(LightningDataModule): Create the testing dataloader :return: The testing dataloader - :rtype: DataLoader + :rtype: torch.utils.data.DataLoader """ return self._create_dataloader("test", self.test_dataset) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index b1d6c71..2effe60 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -15,24 +15,26 @@ class PinaDatasetFactory: Depending on the type inside the conditions, it creates a different dataset object: - - :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and - :class:`LabelTensor` data. - - :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data` - data. + - :class:`pina.data.dataset.PinaTensorDataset` for handling + :class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data. + - :class:`pina.data.dataset.PinaGraphDataset` for handling + :class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. """ def __new__(cls, conditions_dict, **kwargs): """ - Instantiate the appropriate subclass of :class:`PinaDataset`. + Instantiate the appropriate subclass of + :class:`pina.data.dataset.PinaDataset`. If a graph is present in the conditions, returns a - :class:`PinaGraphDataset`, otherwise returns a - :class:`PinaTensorDataset`. + :class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a + :class:`pina.data.dataset.PinaTensorDataset`. :param dict conditions_dict: Dictionary containing all the conditions to be included in the dataset instance. - :return: A subclass of :class:`PinaDataset`. - :rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset` + :return: A subclass of :class:`pina.data.dataset.PinaDataset`. + :rtype: :class:`pina.data.dataset.PinaTensorDataset` | + :class:`pina.data.dataset.PinaGraphDataset` :raises ValueError: If an empty dictionary is provided. """ @@ -73,25 +75,25 @@ class PinaDatasetFactory: class PinaDataset(Dataset, ABC): """ Abstract class for the PINA dataset. It defines the common interface for - the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes. + the :class:`pina.data.dataset.PinaTensorDataset` and + :class:`pina.data.dataset.PinaGraphDataset` classes. """ def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): """ - Initialize a :class:`PinaDataset` instance by storing the provided - conditions dictionary, the maximum number of conditions to consider, - and the automatic batching flag. + Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing + the providedconditions dictionary, the maximum number of conditions to + consider, and the automatic batching flag. - :param conditions_dict: Dictionary containing the conditions. - :type conditions_dict: dict - :param max_conditions_lengths: Specifies the maximum number of data + :param dict conditions_dict: Dictionary containing the conditions with + data. + :param dict max_conditions_lengths: Specifies the maximum number of data points to include in a single batch for each condition. - :type max_conditions_lengths: dict - :param automatic_batching: Indicates whether PyTorch automatic batching - is enabled in :class:`PinaDataModule`. - :type automatic_batching: bool + :param bool automatic_batching: Indicates whether PyTorch automatic + batching is enabled in + :class:`pina.data.data_module.PinaDataModule`. """ # Store the conditions dictionary @@ -134,8 +136,7 @@ class PinaDataset(Dataset, ABC): Return the index itself. This is used when automatic batching is disabled to postpone the data retrieval to the dataloader. - :param idx: Index. - :type idx: int + :param int idx: Index. :return: Index. :rtype: int """ @@ -174,8 +175,7 @@ class PinaDataset(Dataset, ABC): """ Return data from the dataset given a list of indices. - :param idx: List of indices. - :type idx: list[int] + :param list[int] idx: List of indices. :return: A dictionary containing the data at the given indices. :rtype: dict """ @@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC): class PinaTensorDataset(PinaDataset): """ Dataset class for the PINA dataset with :class:`torch.Tensor` and - :class:`LabelTensor` data. + :class:`pina.label_tensor.LabelTensor` data. """ # Override _retrive_data method for torch.Tensor data @@ -213,9 +213,9 @@ class PinaTensorDataset(PinaDataset): """ Retrieve data from the dataset given a list of indices. - :param data: Dictionary containing the data - (only torch.Tensor/LableTensor). - :type data: dict + :param dict data: Dictionary containing the data + (only :class:`torch.Tensor` or + :class:`pina.label_tensor.LabelTensor`). :param list[int] idx_list: indices to retrieve. :return: Dictionary containing the data at the given indices. :rtype: dict @@ -236,17 +236,17 @@ class PinaTensorDataset(PinaDataset): class PinaGraphDataset(PinaDataset): """ - Dataset class for the PINA dataset with :class:`torch_geometric.data.Data` - and :class:`Graph` data. + Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` + and :class:`pina.graph.Graph` data. """ def _create_graph_batch(self, data): """ Create a LabelBatch object from a list of - :class:`torch_geometric.data.Data` objects. + :class:`~torch_geometric.data.Data` objects. :param data: List of items to collate in a single batch. - :type data: list[torch_geometric.data.Data] | list[Graph] + :type data: list[Data] | list[Graph] :return: LabelBatch object all the graph collated in a single batch disconnected graphs. :rtype: LabelBatch @@ -256,13 +256,13 @@ class PinaGraphDataset(PinaDataset): def _create_tensor_batch(self, data): """ - Reshape properly ``data`` tensor to be processed handle by the graph + Reshape properly `data` tensor to be processed handle by the graph based models. :param data: torch.Tensor object of shape (N, ...) where N is the number of data points. :type data: torch.Tensor | LabelTensor - :return: reshaped torch.Tensor or LabelTensor object. + :return: Reshaped tensor object. :rtype: torch.Tensor | LabelTensor """ out = data.reshape(-1, *data.shape[2:]) @@ -270,11 +270,11 @@ class PinaGraphDataset(PinaDataset): def create_batch(self, data): """ - Create a Batch object from a list of :class:`torch_geometric.data.Data` + Create a Batch object from a list of :class:`~torch_geometric.data.Data` objects. :param data: List of items to collate in a single batch. - :type data: list[torch_geometric.data.Data] | list[Graph] + :type data: list[Data] | list[Graph] :return: Batch object. :rtype: Batch | PinaBatch """