From 7b00b80ecb35aba2c45ab322c08236c9b09437e1 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 10 Mar 2025 15:57:15 +0100 Subject: [PATCH] Documentation and docstring graph and data --- pina/data/data_module.py | 189 ++++++++++++++++++++++++++++++--------- pina/data/dataset.py | 146 ++++++++++++++++++++++++++---- pina/graph.py | 90 +++++++++++++------ 3 files changed, 342 insertions(+), 83 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index f68bbc7..2179984 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -23,16 +23,24 @@ class DummyDataloader: def __init__(self, dataset): """ - param dataset: The dataset object to be processed. - :notes: - - **Distributed Environment**: - - Divides the dataset across processes using the - rank and world size. - - Fetches only the portion of data corresponding to - the current process. - - **Non-Distributed Environment**: - - Fetches the entire dataset. + Preprare a dataloader object which will return the entire dataset + in a single batch. Depending on the number of GPUs, the dataset we + have the following cases: + + - **Distributed Environment** (multiple GPUs): + - Divides the dataset across processes using the rank and world + size. + - Fetches only the portion of data corresponding to the current + process. + - **Non-Distributed Environment** (single GPU): + - Fetches the entire dataset. + + :param dataset: The dataset object to be processed. + :type dataset: PinaDataset + + .. note:: This data loader is used when the batch size is None. """ + if ( torch.distributed.is_available() and torch.distributed.is_initialized() @@ -67,23 +75,50 @@ class Collator: Class used to collate the batch """ - def __init__(self, max_conditions_lengths, dataset=None): + def __init__( + self, max_conditions_lengths, automatic_batching, dataset=None + ): + """ + Initialize the object, setting the collate function based on whether + automatic batching is enabled or not. + + :param dict max_conditions_lengths: dict containing the maximum number + of data points to consider in a single batch for each condition. + :param PinaDataset dataset: The dataset where the data is stored. + """ + self.max_conditions_lengths = max_conditions_lengths + # Set the collate function based on the batching strategy + # collate_pina_dataloader is used when automatic batching is disabled + # collate_torch_dataloader is used when automatic batching is enabled self.callable_function = ( - self._collate_custom_dataloader - if max_conditions_lengths is None - else (self._collate_standard_dataloader) + self._collate_torch_dataloader + if automatic_batching + else (self._collate_pina_dataloader) ) self.dataset = dataset + + # Set the function which performs the actual collation if isinstance(self.dataset, PinaTensorDataset): + # If the dataset is a PinaTensorDataset, use this collate function self._collate = self._collate_tensor_dataset else: + # If the dataset is a PinaDataset, use this collate function self._collate = self._collate_graph_dataset - def _collate_custom_dataloader(self, batch): + def _collate_pina_dataloader(self, batch): + """ + Function used to create a batch when automatic batching is disabled. + + :param list(int) batch: List of integers representing the indices of + the data points to be fetched. + :return: Dictionary containing the data points fetched from the dataset. + :rtype: dict + """ + # Call the fetch_from_idx_list method of the dataset return self.dataset.fetch_from_idx_list(batch) - def _collate_standard_dataloader(self, batch): + def _collate_torch_dataloader(self, batch): """ Function used to collate the batch """ @@ -112,6 +147,19 @@ class Collator: @staticmethod def _collate_tensor_dataset(data_list): + """ + Function used to collate the data when the dataset is a + `PinaTensorDataset`. + + :param data_list: List of `torch.Tensor` or `LabelTensor` to be + collated. + :type data_list: list(torch.Tensor) | list(LabelTensor) + :raises RuntimeError: If the data is not a `torch.Tensor` or a + `LabelTensor`. + :return: Batch of data + :rtype: dict + """ + if isinstance(data_list[0], LabelTensor): return LabelTensor.stack(data_list) if isinstance(data_list[0], torch.Tensor): @@ -119,15 +167,36 @@ class Collator: raise RuntimeError("Data must be Tensors or LabelTensor ") def _collate_graph_dataset(self, data_list): + """ + Function used to collate the data when the dataset is a + `PinaGraphDataset`. + + :param data_list: List of `Data` or `Graph` to be collated. + :type data_list: list(Data) | list(Graph) + :raises RuntimeError: If the data is not a `Data` or a `Graph`. + :return: Batch of data + :rtype: dict + """ + if isinstance(data_list[0], LabelTensor): return LabelTensor.cat(data_list) if isinstance(data_list[0], torch.Tensor): return torch.cat(data_list) if isinstance(data_list[0], Data): - return self.dataset.create_graph_batch(data_list) + return self.dataset.create_batch(data_list) raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data") def __call__(self, batch): + """ + Call the function to collate the batch, defined in __init__. + + :param batch: list of indices or list of retrieved data + :type batch: list(int) | list(dict) + :return: Dictionary containing the data points fetched from the dataset, + collated. + :rtype: dict + """ + return self.callable_function(batch) @@ -137,6 +206,16 @@ class PinaSampler: """ def __new__(cls, dataset, shuffle): + """ + Create the sampler instance, according to shuffle and whether the + environment is distributed or not. + + :param PinaDataset dataset: The dataset to be sampled. + :param bool shuffle: whether to shuffle the dataset or not before + sampling. + :return: The sampler instance. + :rtype: torch.utils.data.Sampler + """ if ( torch.distributed.is_available() @@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule): """ Initialize the object, creating datasets based on the input problem. - :param problem: The problem defining the dataset. - :type problem: AbstractProblem - :param train_size: Fraction or number of elements in the training split. - :type train_size: float - :param test_size: Fraction or number of elements in the test split. - :type test_size: float - :param val_size: Fraction or number of elements in the validation split. - :type val_size: float - :param batch_size: Batch size used for training. If None, the entire - dataset is used per batch. - :type batch_size: int or None - :param shuffle: Whether to shuffle the dataset before splitting. - :type shuffle: bool - :param repeat: Whether to repeat the dataset indefinitely. - :type repeat: bool + :param AbstractProblem problem: The problem containing the data on which + to train/test the model. + :param float train_size: Fraction or number of elements in the training + split. + :param float test_size: Fraction or number of elements in the test + split. + :param float val_size: Fraction or number of elements in the validation + split. + :param batch_size: The batch size used for training. If `None`, the + entire dataset is used per batch. + :type batch_size: int | None + :param bool shuffle: Whether to shuffle the dataset before splitting. + :param bool repeat: Whether to repeat the dataset indefinitely. :param automatic_batching: Whether to enable automatic batching. - :type automatic_batching: bool - :param num_workers: Number of worker threads for data loading. + :param int num_workers: Number of worker threads for data loading. Default 0 (serial loading) - :type num_workers: int - :param pin_memory: Whether to use pinned memory for faster data + :param bool pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False) - :type pin_memory: bool """ super().__init__() @@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule): sampler = PinaSampler(dataset, shuffle) if self.automatic_batching: collate = Collator( - self.find_max_conditions_lengths(split), dataset=dataset + self.find_max_conditions_lengths(split), + self.automatic_batching, + dataset=dataset, ) else: - collate = Collator(None, dataset=dataset) + collate = Collator( + None, self.automatic_batching, dataset=dataset + ) return DataLoader( dataset, self.batch_size, @@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule): def train_dataloader(self): """ Create the training dataloader + + :return: The training dataloader + :rtype: DataLoader """ return self._create_dataloader("train", self.train_dataset) def test_dataloader(self): """ Create the testing dataloader + + :return: The testing dataloader + :rtype: DataLoader """ return self._create_dataloader("test", self.test_dataset) @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): + """ + Transfer the batch to the device. This method is called in the + training loop and is used to transfer the batch to the device. + This method is used when the batch size is None: batch has already + been transferred to the device. + + :param list(tuple) batch: list of tuple where the first element of the + tuple is the condition name and the second element is the data. + :param device: device to which the batch is transferred + :type device: torch.device + :param dataloader_idx: index of the dataloader + :type dataloader_idx: int + :return: The batch transferred to the device. + :rtype: list(tuple) + """ return batch def _transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. + + :param dict batch: The batch to be transferred to the device. + :param device: The device to which the batch is transferred. + :type device: torch.device + :param int dataloader_idx: The index of the dataloader. + :return: The batch transferred to the device. + :rtype: list(tuple) """ batch = [ ( @@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule): @property def input(self): """ - # TODO + Return all the input points coming from all the datasets. + + :return: The input points for training. + :rtype dict """ to_return = {} if hasattr(self, "train_dataset") and self.train_dataset is not None: @@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule): if hasattr(self, "val_dataset") and self.val_dataset is not None: to_return["val"] = self.val_dataset.input if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return = self.test_dataset.input + to_return["test"] = self.test_dataset.input return to_return diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 3174b4b..e716621 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch class PinaDatasetFactory: """ - Factory class for the PINA dataset. Depending on the type inside the - conditions it creates a different dataset object: - - PinaTensorDataset for torch.Tensor - - PinaGraphDataset for list of torch_geometric.data.Data objects + Factory class for the PINA dataset. + + Depending on the type inside the conditions, it creates a different dataset + object: + + - :class:`PinaTensorDataset` for `torch.Tensor` + - :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data` + objects """ def __new__(cls, conditions_dict, **kwargs): + """ + Instantiate the appropriate subclass of :class:`PinaDataset`. + + If a graph is present in the conditions, returns a + :class:`PinaGraphDataset`, otherwise returns a + :class:`PinaTensorDataset`. + + :param dict conditions_dict: Dictionary containing the conditions. + :return: A subclass of :class:`PinaDataset`. + :rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset` + + :raises ValueError: If an empty dictionary is provided. + """ + # Check if conditions_dict is empty if len(conditions_dict) == 0: raise ValueError("No conditions provided") @@ -31,9 +49,21 @@ class PinaDatasetFactory: @staticmethod def _is_graph_dataset(conditions_dict): + """ + Check if a graph is present in the conditions. + + :param conditions_dict: Dictionary containing the conditions. + :type conditions_dict: dict + :return: True if a graph is present in the conditions, False otherwise + :rtype: bool + """ + + # Iterate over the conditions dictionary for v in conditions_dict.values(): + # Iterate over the values of the current condition for cond in v.values(): - if isinstance(cond, (Data, Graph, list)): + # Check if the current value is a list of Data objects + if isinstance(cond, (Data, Graph, list, tuple)): return True return False @@ -46,6 +76,19 @@ class PinaDataset(Dataset): def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): + """ + Initialize the :class:`PinaDataset`. + + Stores the conditions dictionary, the maximum number of conditions to + consider, and the automatic batching flag. + + :param dict conditions_dict: Dictionary containing the conditions. + :param dict max_conditions_lengths: Maximum number of data points to + consider in a single batch for each condition. + :param bool automatic_batching: Whether PyTorch automatic batching is + enabled in :class:`PinaDataModule`. + """ + # Store the conditions dictionary self.conditions_dict = conditions_dict # Store the maximum number of conditions to consider @@ -63,7 +106,13 @@ class PinaDataset(Dataset): self._getitem_func = self._getitem_dummy def _get_max_len(self): - """""" + """ + Returns the length of the longest condition in the dataset + + :return: Length of the longest condition in the dataset + :rtype: int + """ + max_len = 0 for condition in self.conditions_dict.values(): max_len = max(max_len, len(condition["input"])) @@ -76,10 +125,29 @@ class PinaDataset(Dataset): return self._getitem_func(idx) def _getitem_dummy(self, idx): + """ + Return the index itself. This is used when automatic batching is + disabled to postpone the data retrieval to the dataloader. + + :param idx: Index + :type idx: int + :return: Index + :rtype: int + """ + # If automatic batching is disabled, return the data at the given index return idx def _getitem_int(self, idx): + """ + Return the data at the given index in the dataset. This is used when + automatic batching is enabled. + + :param int idx: Index + :return: A dictionary containing the data at the given index + :rtype: dict + """ + # If automatic batching is enabled, return the data at the given index return { k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()} @@ -121,7 +189,14 @@ class PinaDataset(Dataset): @abstractmethod def _retrive_data(self, data, idx_list): - pass + """ + Retrieve data from the dataset given a list of indices + + :param dict data: Dictionary containing the data + :param list idx_list: List of indices to retrieve + :return: Dictionary containing the data at the given indices + :rtype: dict + """ class PinaTensorDataset(PinaDataset): @@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset): # Override _retrive_data method for torch.Tensor data def _retrive_data(self, data, idx_list): + """ + Retrieve data from the dataset given a list of indices + + :param data: Dictionary containing the data + (only torch.Tensor/LableTensor) + :type data: dict + :param list(int) idx_list: indices to retrieve + :return: Dictionary containing the data at the given indices + :rtype: dict + """ + return {k: v[idx_list] for k, v in data.items()} @property def input(self): """ - Method to return input points for training. + Method to return all input points from the dataset. + + :return: Dictionary containing the input points + :rtype: dict """ return {k: v["input"] for k, v in self.conditions_dict.items()} @@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset): Class for the PINA dataset with torch_geometric.data.Data data """ - def _create_graph_batch_from_list(self, data): + def _create_graph_batch(self, data): + """ + Create a LabelBatch object from a list of Data objects. + + :param data: List of Data or Graph objects + :type data: list(Data) | list(Graph) + :return: LabelBatch object all the graph collated in a single batch + disconnected graphs. + :rtype: LabelBatch + """ batch = LabelBatch.from_data_list(data) return batch - def _create_output_batch(self, data): + def _create_tensor_batch(self, data): + """ + Create a torch.Tensor object from a list of torch.Tensor objects. + + :param data: torch.Tensor object of shape (N, ...) where N is the + number of data points. + :type data: torch.Tensor | LabelTensor + :return: reshaped torch.Tensor or LabelTensor object + :rtype: torch.Tensor | LabelTensor + """ out = data.reshape(-1, *data.shape[2:]) return out - def create_graph_batch(self, data): + def create_batch(self, data): """ Create a Batch object from a list of Data objects. @@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset): :return: Batch object :rtype: Batch or PinaBatch """ + if isinstance(data[0], Data): - return self._create_graph_batch_from_list(data) - return self._create_output_batch(data) + return self._create_graph_batch(data) + return self._create_tensor_batch(data) # Override _retrive_data method for graph handling def _retrive_data(self, data, idx_list): + """ + Retrieve data from the dataset given a list of indices + + :param dict data: dictionary containing the data + :param list idx_list: list of indices to retrieve + :return: dictionary containing the data at the given indices + :rtype: dict + """ # Return the data from the current condition # If the data is a list of Data objects, create a Batch object # If the data is a list of torch.Tensor objects, create a torch.Tensor return { k: ( - self._create_graph_batch_from_list([v[i] for i in idx_list]) + self._create_graph_batch([v[i] for i in idx_list]) if isinstance(v, list) - else self._create_output_batch(v[idx_list]) + else self._create_tensor_batch(v[idx_list]) ) for k, v in data.items() } diff --git a/pina/graph.py b/pina/graph.py index 39d4bcf..d4a5a19 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -19,6 +19,9 @@ class Graph(Data): **kwargs, ): """ + Instantiates a new instance of the Graph class, performing type + consistency checks. + :param kwargs: Parameters to construct the Graph object. :return: A new instance of the Graph class. :rtype: Graph @@ -42,7 +45,10 @@ class Graph(Data): **kwargs, ): """ - Initialize the Graph object. + Initialize the Graph object by setting the node features, edge index, + edge attributes, and positions. The edge index is preprocessed to make + the graph undirected if required. For more details, see the + :meth: `torch_geometric.data.Data` :param x: Optional tensor of node features (N, F) where F is the number of features per node. @@ -69,6 +75,13 @@ class Graph(Data): ) def _check_type_consistency(self, **kwargs): + """ + Check the consistency of the types of the input data. + + :param kwargs: Attributes to be checked for consistency. + :type kwargs: dict + """ + # default types, specified in cls.__new__, by default they are Nont # if specified in **kwargs they get override x, pos, edge_index, edge_attr = None, None, None, None @@ -92,8 +105,10 @@ class Graph(Data): def _check_pos_consistency(pos): """ Check if the position tensor is consistent. + :param torch.Tensor pos: The position tensor. """ + if pos is not None: check_consistency(pos, (torch.Tensor, LabelTensor)) if pos.ndim != 2: @@ -103,8 +118,10 @@ class Graph(Data): def _check_edge_index_consistency(edge_index): """ Check if the edge index is consistent. + :param torch.Tensor edge_index: The edge index tensor. """ + check_consistency(edge_index, (torch.Tensor, LabelTensor)) if edge_index.ndim != 2: raise ValueError("edge_index must be a 2D tensor.") @@ -114,11 +131,13 @@ class Graph(Data): @staticmethod def _check_edge_attr_consistency(edge_attr, edge_index): """ - Check if the edge attr is consistent. - :param torch.Tensor edge_attr: The edge attribute tensor. + Check if the edge attribute tensor is consistent in type and shape + with the edge index. + :param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_index: The edge index tensor. """ + if edge_attr is not None: check_consistency(edge_attr, (torch.Tensor, LabelTensor)) if edge_attr.ndim != 2: @@ -134,10 +153,13 @@ class Graph(Data): @staticmethod def _check_x_consistency(x, pos=None): """ - Check if the input tensor x is consistent with the position tensor pos. + Check if the input tensor x is consistent with the position tensor + `pos`. + :param torch.Tensor x: The input tensor. :param torch.Tensor pos: The position tensor. """ + if x is not None: check_consistency(x, (torch.Tensor, LabelTensor)) if x.ndim != 2: @@ -152,22 +174,24 @@ class Graph(Data): @staticmethod def _preprocess_edge_index(edge_index, undirected): """ - Preprocess the edge index. + Preprocess the edge index to make the graph undirected (if required). + :param torch.Tensor edge_index: The edge index. :param bool undirected: Whether the graph is undirected. :return: The preprocessed edge index. :rtype: torch.Tensor """ + if undirected: edge_index = to_undirected(edge_index) return edge_index def extract(self, labels, attr="x"): """ - Perform extraction of labels on node features (x) + Perform extraction of labels from the attribute specified by `attr`. :param labels: Labels to extract - :type labels: list[str] | tuple[str] | str + :type labels: list[str] | tuple[str] | str | dict :return: Batch object with extraction performed on x :rtype: PinaBatch """ @@ -193,21 +217,24 @@ class GraphBuilder: **kwargs, ): """ - Creates a new instance of the Graph class. + Compute the edge attributes and create a new instance of the Graph + class. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. - :type pos: torch.Tensor | LabelTensor + :type pos: torch.Tensor or LabelTensor :param edge_index: A tensor of shape (2, E) representing the indices of the graph's edges. :type edge_index: torch.Tensor - :param x: Optional tensor of node features (N, F) where F is the number - of features per node. - :type x: torch.Tensor, LabelTensor - :param bool edge_attr: Optional edge attributes (E, F) where F is the - number of features per edge. - :param callable custom_edge_func: A custom function to compute edge - attributes. + :param x: Optional tensor of node features of shape (N, F), where F is + the number of features per node. + :type x: torch.Tensor | LabelTensor, optional + :param edge_attr: Optional tensor of edge attributes of shape (E, F), + where F is the number of features per edge. + :type edge_attr: torch.Tensor, optional + :param custom_edge_func: A custom function to compute edge attributes. + If provided, overrides `edge_attr`. + :type custom_edge_func: callable, optional :param kwargs: Additional keyword arguments passed to the Graph class constructor. :return: A Graph instance constructed using the provided information. @@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder): def __new__(cls, pos, radius, **kwargs): """ - Creates a new instance of the Graph class using a radius-based graph - construction. + Extends the `GraphBuilder` class to compute edge_index based on a + radius. Each point is connected to all the points within the radius. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. - :type pos: torch.Tensor | LabelTensor - :param float radius: The radius within which points are connected. - :Keyword Arguments: - The additional keyword arguments to be passed to GraphBuilder - and Graph classes - :return: Graph instance containg the information passed in input and - the computed edge_index + :type pos: torch.Tensor or LabelTensor + :param radius: The radius within which points are connected. + :type radius: float + :param kwargs: Additional keyword arguments to be passed to the + `GraphBuilder` and `Graph` constructors. + :return: A `Graph` instance containing the input information and the + computed edge_index. :rtype: Graph """ edge_index = cls.compute_radius_graph(pos, radius) @@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder): @staticmethod def compute_radius_graph(points, radius): """ - Computes a radius-based graph for a given set of points. + Computes edge_index for a given set of points base on the radius. + Each point is connected to all the points within the radius. :param points: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. @@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder): def __new__(cls, pos, neighbours, **kwargs): """ Creates a new instance of the Graph class using k-nearest neighbors - to compute edge_index. + algorithm to define the edges. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. @@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder): N points in D-dimensional space. :type points: torch.Tensor | LabelTensor :param int k: The number of nearest neighbors to find for each point. - :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of + :return: A tensor of shape (2, E), where E is the number of edges, representing the edge indices of the KNN graph. + :rtype: torch.Tensor """ dist = torch.cdist(points, points, p=2) @@ -343,6 +372,11 @@ class LabelBatch(Batch): def from_data_list(cls, data_list): """ Create a Batch object from a list of Data objects. + + :param data_list: List of Data/Graph objects + :type data_list: list[Data] | list[Graph] + :return: A Batch object containing the data in the list + :rtype: Batch """ # Store the labels of Data/Graph objects (all data have the same labels) # If the data do not contain labels, labels is an empty dictionary,