Documentation and docstring graph and data

This commit is contained in:
FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions

View File

@@ -23,16 +23,24 @@ class DummyDataloader:
def __init__(self, dataset):
"""
param dataset: The dataset object to be processed.
:notes:
- **Distributed Environment**:
- Divides the dataset across processes using the
rank and world size.
- Fetches only the portion of data corresponding to
the current process.
- **Non-Distributed Environment**:
Prepare a dataloader object which will return the entire dataset
in a single batch. Depending on the number of GPUs, the dataset is
handled in one of the following ways:
- **Distributed Environment** (multiple GPUs):
- Divides the dataset across processes using the rank and world
size.
- Fetches only the portion of data corresponding to the current
process.
- **Non-Distributed Environment** (single GPU):
- Fetches the entire dataset.
:param dataset: The dataset object to be processed.
:type dataset: PinaDataset
.. note:: This data loader is used when the batch size is None.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
@@ -67,23 +75,50 @@ class Collator:
Class used to collate the batch
"""
def __init__(self, max_conditions_lengths, dataset=None):
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: dict containing the maximum number
of data points to consider in a single batch for each condition.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_custom_dataloader
if max_conditions_lengths is None
else (self._collate_standard_dataloader)
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_custom_dataloader(self, batch):
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list(int) batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_standard_dataloader(self, batch):
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch
"""
@@ -112,6 +147,19 @@ class Collator:
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
`PinaTensorDataset`.
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
collated.
:type data_list: list(torch.Tensor) | list(LabelTensor)
:raises RuntimeError: If the data is not a `torch.Tensor` or a
`LabelTensor`.
:return: Batch of data
:rtype: torch.Tensor | LabelTensor
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
@@ -119,15 +167,36 @@ class Collator:
raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
`PinaGraphDataset`.
:param data_list: List of `Data` or `Graph` to be collated.
:type data_list: list(Data) | list(Graph)
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_graph_batch(data_list)
return self.dataset.create_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
def __call__(self, batch):
"""
Call the function to collate the batch, defined in __init__.
:param batch: list of indices or list of retrieved data
:type batch: list(int) | list(dict)
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
return self.callable_function(batch)
@@ -137,6 +206,16 @@ class PinaSampler:
"""
def __new__(cls, dataset, shuffle):
"""
Create the sampler instance, according to shuffle and whether the
environment is distributed or not.
:param PinaDataset dataset: The dataset to be sampled.
:param bool shuffle: whether to shuffle the dataset or not before
sampling.
:return: The sampler instance.
:rtype: torch.utils.data.Sampler
"""
if (
torch.distributed.is_available()
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
"""
Initialize the object, creating datasets based on the input problem.
:param problem: The problem defining the dataset.
:type problem: AbstractProblem
:param train_size: Fraction or number of elements in the training split.
:type train_size: float
:param test_size: Fraction or number of elements in the test split.
:type test_size: float
:param val_size: Fraction or number of elements in the validation split.
:type val_size: float
:param batch_size: Batch size used for training. If None, the entire
dataset is used per batch.
:type batch_size: int or None
:param shuffle: Whether to shuffle the dataset before splitting.
:type shuffle: bool
:param repeat: Whether to repeat the dataset indefinitely.
:type repeat: bool
:param AbstractProblem problem: The problem containing the data on which
to train/test the model.
:param float train_size: Fraction or number of elements in the training
split.
:param float test_size: Fraction or number of elements in the test
split.
:param float val_size: Fraction or number of elements in the validation
split.
:param batch_size: The batch size used for training. If `None`, the
entire dataset is used per batch.
:type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting.
:param bool repeat: Whether to repeat the dataset indefinitely.
:param automatic_batching: Whether to enable automatic batching.
:type automatic_batching: bool
:param num_workers: Number of worker threads for data loading.
:param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading)
:type num_workers: int
:param pin_memory: Whether to use pinned memory for faster data
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False)
:type pin_memory: bool
"""
super().__init__()
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
sampler = PinaSampler(dataset, shuffle)
if self.automatic_batching:
collate = Collator(
self.find_max_conditions_lengths(split), dataset=dataset
self.find_max_conditions_lengths(split),
self.automatic_batching,
dataset=dataset,
)
else:
collate = Collator(None, dataset=dataset)
collate = Collator(
None, self.automatic_batching, dataset=dataset
)
return DataLoader(
dataset,
self.batch_size,
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
def train_dataloader(self):
"""
Create the training dataloader
:return: The training dataloader
:rtype: DataLoader
"""
return self._create_dataloader("train", self.train_dataset)
def test_dataloader(self):
"""
Create the testing dataloader
:return: The testing dataloader
:rtype: DataLoader
"""
return self._create_dataloader("test", self.test_dataset)
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
This method is used when the batch size is None: batch has already
been transferred to the device.
:param list(tuple) batch: list of tuple where the first element of the
tuple is the condition name and the second element is the data.
:param device: device to which the batch is transferred
:type device: torch.device
:param dataloader_idx: index of the dataloader
:type dataloader_idx: int
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
return batch
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
:param dict batch: The batch to be transferred to the device.
:param device: The device to which the batch is transferred.
:type device: torch.device
:param int dataloader_idx: The index of the dataloader.
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
batch = [
(
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
@property
def input(self):
"""
# TODO
Return all the input points coming from all the datasets.
:return: The input points for training.
:rtype: dict
"""
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return = self.test_dataset.input
to_return["test"] = self.test_dataset.input
return to_return

View File

@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
class PinaDatasetFactory:
"""
Factory class for the PINA dataset. Depending on the type inside the
conditions it creates a different dataset object:
- PinaTensorDataset for torch.Tensor
- PinaGraphDataset for list of torch_geometric.data.Data objects
Factory class for the PINA dataset.
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`PinaTensorDataset` for `torch.Tensor`
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
objects
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Instantiate the appropriate subclass of :class:`PinaDataset`.
If a graph is present in the conditions, returns a
:class:`PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing the conditions.
:return: A subclass of :class:`PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided.
"""
# Check if conditions_dict is empty
if len(conditions_dict) == 0:
raise ValueError("No conditions provided")
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Check if a graph is present in the conditions.
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise
:rtype: bool
"""
# Iterate over the conditions dictionary
for v in conditions_dict.values():
# Iterate over the values of the current condition
for cond in v.values():
if isinstance(cond, (Data, Graph, list)):
# Check if the current value is a list of Data objects
if isinstance(cond, (Data, Graph, list, tuple)):
return True
return False
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the :class:`PinaDataset`.
Stores the conditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict max_conditions_lengths: Maximum number of data points to
consider in a single batch for each condition.
:param bool automatic_batching: Whether PyTorch automatic batching is
enabled in :class:`PinaDataModule`.
"""
# Store the conditions dictionary
self.conditions_dict = conditions_dict
# Store the maximum number of conditions to consider
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
self._getitem_func = self._getitem_dummy
def _get_max_len(self):
""""""
"""
Returns the length of the longest condition in the dataset
:return: Length of the longest condition in the dataset
:rtype: int
"""
max_len = 0
for condition in self.conditions_dict.values():
max_len = max(max_len, len(condition["input"]))
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
return self._getitem_func(idx)
def _getitem_dummy(self, idx):
"""
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.
:param idx: Index
:type idx: int
:return: Index
:rtype: int
"""
# If automatic batching is disabled, return the index itself
return idx
def _getitem_int(self, idx):
"""
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index
:return: A dictionary containing the data at the given index
:rtype: dict
"""
# If automatic batching is enabled, return the data at the given index
return {
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
@abstractmethod
def _retrive_data(self, data, idx_list):
pass
"""
Retrieve data from the dataset given a list of indices
:param dict data: Dictionary containing the data
:param list idx_list: List of indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
"""
class PinaTensorDataset(PinaDataset):
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param data: Dictionary containing the data
(only torch.Tensor/LabelTensor)
:type data: dict
:param list(int) idx_list: indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
"""
return {k: v[idx_list] for k, v in data.items()}
@property
def input(self):
"""
Method to return input points for training.
Method to return all input points from the dataset.
:return: Dictionary containing the input points
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
Class for the PINA dataset with torch_geometric.data.Data data
"""
def _create_graph_batch_from_list(self, data):
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of Data objects.
:param data: List of Data or Graph objects
:type data: list(Data) | list(Graph)
:return: LabelBatch object with all the graphs collated in a single
batch of disconnected graphs.
:rtype: LabelBatch
"""
batch = LabelBatch.from_data_list(data)
return batch
def _create_output_batch(self, data):
def _create_tensor_batch(self, data):
"""
Create a torch.Tensor object from a list of torch.Tensor objects.
:param data: torch.Tensor object of shape (N, ...) where N is the
number of data points.
:type data: torch.Tensor | LabelTensor
:return: reshaped torch.Tensor or LabelTensor object
:rtype: torch.Tensor | LabelTensor
"""
out = data.reshape(-1, *data.shape[2:])
return out
def create_graph_batch(self, data):
def create_batch(self, data):
"""
Create a Batch object from a list of Data objects.
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
:return: Batch object
:rtype: Batch or PinaBatch
"""
if isinstance(data[0], Data):
return self._create_graph_batch_from_list(data)
return self._create_output_batch(data)
return self._create_graph_batch(data)
return self._create_tensor_batch(data)
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param dict data: dictionary containing the data
:param list idx_list: list of indices to retrieve
:return: dictionary containing the data at the given indices
:rtype: dict
"""
# Return the data from the current condition
# If the data is a list of Data objects, create a Batch object
# If the data is a list of torch.Tensor objects, create a torch.Tensor
return {
k: (
self._create_graph_batch_from_list([v[i] for i in idx_list])
self._create_graph_batch([v[i] for i in idx_list])
if isinstance(v, list)
else self._create_output_batch(v[idx_list])
else self._create_tensor_batch(v[idx_list])
)
for k, v in data.items()
}

View File

@@ -19,6 +19,9 @@ class Graph(Data):
**kwargs,
):
"""
Instantiates a new instance of the Graph class, performing type
consistency checks.
:param kwargs: Parameters to construct the Graph object.
:return: A new instance of the Graph class.
:rtype: Graph
@@ -42,7 +45,10 @@ class Graph(Data):
**kwargs,
):
"""
Initialize the Graph object.
Initialize the Graph object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the
:class:`torch_geometric.data.Data` documentation.
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
@@ -69,6 +75,13 @@ class Graph(Data):
)
def _check_type_consistency(self, **kwargs):
"""
Check the consistency of the types of the input data.
:param kwargs: Attributes to be checked for consistency.
:type kwargs: dict
"""
# default types, specified in cls.__new__, by default they are None
# if specified in **kwargs they get overridden
x, pos, edge_index, edge_attr = None, None, None, None
@@ -92,8 +105,10 @@ class Graph(Data):
def _check_pos_consistency(pos):
"""
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
"""
if pos is not None:
check_consistency(pos, (torch.Tensor, LabelTensor))
if pos.ndim != 2:
@@ -103,8 +118,10 @@ class Graph(Data):
def _check_edge_index_consistency(edge_index):
"""
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
if edge_index.ndim != 2:
raise ValueError("edge_index must be a 2D tensor.")
@@ -114,11 +131,13 @@ class Graph(Data):
@staticmethod
def _check_edge_attr_consistency(edge_attr, edge_index):
"""
Check if the edge attr is consistent.
:param torch.Tensor edge_attr: The edge attribute tensor.
Check if the edge attribute tensor is consistent in type and shape
with the edge index.
:param torch.Tensor edge_attr: The edge attribute tensor.
:param torch.Tensor edge_index: The edge index tensor.
"""
if edge_attr is not None:
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
if edge_attr.ndim != 2:
@@ -134,10 +153,13 @@ class Graph(Data):
@staticmethod
def _check_x_consistency(x, pos=None):
"""
Check if the input tensor x is consistent with the position tensor pos.
Check if the input tensor x is consistent with the position tensor
`pos`.
:param torch.Tensor x: The input tensor.
:param torch.Tensor pos: The position tensor.
"""
if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor))
if x.ndim != 2:
@@ -152,22 +174,24 @@ class Graph(Data):
@staticmethod
def _preprocess_edge_index(edge_index, undirected):
"""
Preprocess the edge index.
Preprocess the edge index to make the graph undirected (if required).
:param torch.Tensor edge_index: The edge index.
:param bool undirected: Whether the graph is undirected.
:return: The preprocessed edge index.
:rtype: torch.Tensor
"""
if undirected:
edge_index = to_undirected(edge_index)
return edge_index
def extract(self, labels, attr="x"):
"""
Perform extraction of labels on node features (x)
Perform extraction of labels from the attribute specified by `attr`.
:param labels: Labels to extract
:type labels: list[str] | tuple[str] | str
:type labels: list[str] | tuple[str] | str | dict
:return: Batch object with extraction performed on x
:rtype: PinaBatch
"""
@@ -193,21 +217,24 @@ class GraphBuilder:
**kwargs,
):
"""
Creates a new instance of the Graph class.
Compute the edge attributes and create a new instance of the Graph
class.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:type pos: torch.Tensor | LabelTensor
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape (2, E) representing the indices of
the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features (N, F) where F is the number
of features per node.
:type x: torch.Tensor, LabelTensor
:param bool edge_attr: Optional edge attributes (E, F) where F is the
number of features per edge.
:param callable custom_edge_func: A custom function to compute edge
attributes.
:param x: Optional tensor of node features of shape (N, F), where F is
the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
where F is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides `edge_attr`.
:type custom_edge_func: callable, optional
:param kwargs: Additional keyword arguments passed to the Graph class
constructor.
:return: A Graph instance constructed using the provided information.
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
def __new__(cls, pos, radius, **kwargs):
"""
Creates a new instance of the Graph class using a radius-based graph
construction.
Extends the `GraphBuilder` class to compute edge_index based on a
radius. Each point is connected to all the points within the radius.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:Keyword Arguments:
The additional keyword arguments to be passed to GraphBuilder
and Graph classes
:return: Graph instance containg the information passed in input and
the computed edge_index
:type pos: torch.Tensor or LabelTensor
:param radius: The radius within which points are connected.
:type radius: float
:param kwargs: Additional keyword arguments to be passed to the
`GraphBuilder` and `Graph` constructors.
:return: A `Graph` instance containing the input information and the
computed edge_index.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes a radius-based graph for a given set of points.
Computes edge_index for a given set of points based on the radius.
Each point is connected to all the points within the radius.
:param points: A tensor of shape (N, D) representing the positions of
N points in D-dimensional space.
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
def __new__(cls, pos, neighbours, **kwargs):
"""
Creates a new instance of the Graph class using k-nearest neighbors
to compute edge_index.
algorithm to define the edges.
:param pos: A tensor of shape (N, D) representing the positions of N
points in D-dimensional space.
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
N points in D-dimensional space.
:type points: torch.Tensor | LabelTensor
:param int k: The number of nearest neighbors to find for each point.
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
:return: A tensor of shape (2, E), where E is the number of
edges, representing the edge indices of the KNN graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of Data objects.
:param data_list: List of Data/Graph objects
:type data_list: list[Data] | list[Graph]
:return: A Batch object containing the data in the list
:rtype: Batch
"""
# Store the labels of Data/Graph objects (all data have the same labels)
# If the data do not contain labels, labels is an empty dictionary,