Documentation and docstring graph and data

This commit is contained in:
FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions

View File

@@ -23,16 +23,24 @@ class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object which returns the entire dataset
in a single batch. Depending on the number of GPUs, we have the
following cases:
- **Distributed Environment** (multiple GPUs):
- Divides the dataset across processes using the rank and world
size.
- Fetches only the portion of data corresponding to the current
process.
- **Non-Distributed Environment** (single GPU):
- Fetches the entire dataset.
:param dataset: The dataset object to be processed.
:type dataset: PinaDataset
.. note:: This data loader is used when the batch size is None.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
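
A minimal sketch of the slicing behaviour described above, assuming a hypothetical _shard helper (the real arithmetic lives in the code truncated by this hunk):

def _shard(dataset, rank, world_size):
    # Hypothetical helper: divide the dataset across processes using the
    # rank and world size; each process keeps only its own slice.
    per_rank = len(dataset) // world_size
    start = rank * per_rank
    # Assumption: the last rank absorbs the remainder so no sample is lost.
    end = len(dataset) if rank == world_size - 1 else start + per_rank
    return [dataset[i] for i in range(start, end)]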
@@ -67,23 +75,50 @@ class Collator:
Class used to collate the batch
"""
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: dict containing the maximum number
of data points to consider in a single batch for each condition.
:param bool automatic_batching: Whether automatic batching is enabled.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list(int) batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch when automatic batching is enabled.
"""
@@ -112,6 +147,19 @@ class Collator:
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
`PinaTensorDataset`.
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
collated.
:type data_list: list(torch.Tensor) | list(LabelTensor)
:raises RuntimeError: If the data is not a `torch.Tensor` or a
`LabelTensor`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
@@ -119,15 +167,36 @@ class Collator:
raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
`PinaGraphDataset`.
:param data_list: List of `Data` or `Graph` to be collated.
:type data_list: list(Data) | list(Graph)
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
def __call__(self, batch):
"""
Call the function to collate the batch, defined in __init__.
:param batch: list of indices or list of retrieved data
:type batch: list(int) | list(dict)
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
return self.callable_function(batch)
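
For the graph path, a self-contained illustration of collating pyG Data objects into one disconnected batch (plain torch_geometric Batch here; LabelBatch plays the same role in create_batch above):

import torch
from torch_geometric.data import Data, Batch

graphs = [
    Data(x=torch.rand(3, 2), edge_index=torch.tensor([[0, 1], [1, 2]]))
    for _ in range(4)
]
# from_data_list merges the graphs into one big disconnected graph,
# offsetting the node indices of each graph.
batch = Batch.from_data_list(graphs)
print(batch.num_graphs)  # 4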
@@ -137,6 +206,16 @@ class PinaSampler:
"""
def __new__(cls, dataset, shuffle):
"""
Create the sampler instance, according to shuffle and whether the
environment is distributed or not.
:param PinaDataset dataset: The dataset to be sampled.
:param bool shuffle: Whether to shuffle the dataset or not before
sampling.
:return: The sampler instance.
:rtype: torch.utils.data.Sampler
"""
if (
torch.distributed.is_available()
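
The branch is truncated by the hunk; a plausible reading of the selection logic, using the standard torch samplers implied by the :rtype: above (an assumption, not verbatim source):

import torch
from torch.utils.data import (
    DistributedSampler,
    RandomSampler,
    SequentialSampler,
)

def _choose_sampler(dataset, shuffle):
    # Hypothetical mirror of PinaSampler.__new__: distributed runs get a
    # DistributedSampler; otherwise shuffle picks random vs sequential.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return DistributedSampler(dataset, shuffle=shuffle)
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)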
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
"""
Initialize the object, creating datasets based on the input problem.
:param AbstractProblem problem: The problem containing the data on which
to train/test the model.
:param float train_size: Fraction or number of elements in the training
split.
:param float test_size: Fraction or number of elements in the test
split.
:param float val_size: Fraction or number of elements in the validation
split.
:param batch_size: The batch size used for training. If `None`, the
entire dataset is used per batch.
:type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting.
:param bool repeat: Whether to repeat the dataset indefinitely.
:param bool automatic_batching: Whether to enable automatic batching.
:param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading)
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False)
"""
super().__init__()
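
A hypothetical construction matching the documented signature (problem stands for any AbstractProblem instance; keyword names follow the docstring above):

dm = PinaDataModule(
    problem,
    train_size=0.7,
    test_size=0.2,
    val_size=0.1,
    batch_size=32,  # None would return the whole dataset in one batch
    shuffle=True,
    repeat=False,
    automatic_batching=False,
    num_workers=0,  # serial loading
    pin_memory=False,
)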
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
sampler = PinaSampler(dataset, shuffle)
if self.automatic_batching:
collate = Collator(
self.find_max_conditions_lengths(split),
self.automatic_batching,
dataset=dataset,
)
else:
collate = Collator(
None, self.automatic_batching, dataset=dataset
)
return DataLoader(
dataset,
self.batch_size,
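
The Collator ends up as the collate_fn of a plain torch DataLoader; a stand-alone analogue of the index-based path:

from torch.utils.data import DataLoader

# With automatic batching disabled the dataset yields bare indices, and
# collate_fn receives the whole index list, as _collate_pina_dataloader does.
loader = DataLoader(range(8), batch_size=4, collate_fn=lambda idxs: {"idx": idxs})
print(next(iter(loader)))  # {'idx': [0, 1, 2, 3]}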
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
def train_dataloader(self):
"""
Create the training dataloader
:return: The training dataloader
:rtype: DataLoader
"""
return self._create_dataloader("train", self.train_dataset)
def test_dataloader(self):
"""
Create the testing dataloader
:return: The testing dataloader
:rtype: DataLoader
"""
return self._create_dataloader("test", self.test_dataset)
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
This method is used when the batch size is None: the batch has
already been transferred to the device.
:param list(tuple) batch: list of tuple where the first element of the
tuple is the condition name and the second element is the data.
:param device: device to which the batch is transferred
:type device: torch.device
:param dataloader_idx: index of the dataloader
:type dataloader_idx: int
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
return batch
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
:param dict batch: The batch to be transferred to the device.
:param device: The device to which the batch is transferred.
:type device: torch.device
:param int dataloader_idx: The index of the dataloader.
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
batch = [
(
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
@property
def input(self):
"""
Return all the input points coming from all the datasets.
:return: The input points for training.
:rtype: dict
"""
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return["test"] = self.test_dataset.input
return to_return
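
After this change the property returns one entry per existing split, e.g.:

inputs = data_module.input  # data_module: a set-up PinaDataModule (assumed)
# inputs == {'train': {...}, 'val': {...}, 'test': {...}}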

View File

@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
class PinaDatasetFactory:
"""
Factory class for the PINA dataset.
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`PinaTensorDataset` for `torch.Tensor`
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
objects
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Instantiate the appropriate subclass of :class:`PinaDataset`.
If a graph is present in the conditions, returns a
:class:`PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing the conditions.
:return: A subclass of :class:`PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided.
"""
# Check if conditions_dict is empty
if len(conditions_dict) == 0:
raise ValueError("No conditions provided")
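
A hedged dispatch sketch (the extra keyword arguments are assumed to be forwarded to the chosen subclass, mirroring PinaDataset.__init__ below):

import torch

conditions = {"data": {"input": torch.rand(10, 2), "target": torch.rand(10, 1)}}
dataset = PinaDatasetFactory(
    conditions,
    max_conditions_lengths={"data": 10},
    automatic_batching=False,
)  # -> PinaTensorDataset, since no Data/Graph objects appear in the conditions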
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Check if a graph is present in the conditions.
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise
:rtype: bool
"""
# Iterate over the conditions dictionary
for v in conditions_dict.values():
# Iterate over the values of the current condition
for cond in v.values():
# Check if the current value is a list of Data objects
if isinstance(cond, (Data, Graph, list, tuple)):
return True
return False
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the :class:`PinaDataset`.
Stores the conditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict max_conditions_lengths: Maximum number of data points to
consider in a single batch for each condition.
:param bool automatic_batching: Whether PyTorch automatic batching is
enabled in :class:`PinaDataModule`.
"""
# Store the conditions dictionary
self.conditions_dict = conditions_dict
# Store the maximum number of conditions to consider
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
self._getitem_func = self._getitem_dummy
def _get_max_len(self):
""""""
"""
Returns the length of the longest condition in the dataset
:return: Length of the longest condition in the dataset
:rtype: int
"""
max_len = 0
for condition in self.conditions_dict.values():
max_len = max(max_len, len(condition["input"]))
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
return self._getitem_func(idx)
def _getitem_dummy(self, idx):
"""
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.
:param idx: Index
:type idx: int
:return: Index
:rtype: int
"""
# If automatic batching is disabled, return the index itself
return idx
def _getitem_int(self, idx):
"""
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index
:return: A dictionary containing the data at the given index
:rtype: dict
"""
# If automatic batching is enabled, return the data at the given index
return {
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
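
Side by side, the two modes on a hypothetical dataset ds:

# Automatic batching disabled -> _getitem_dummy: the index is echoed and
# the real retrieval is postponed to the Collator.
ds[5]  # -> 5
# Automatic batching enabled -> _getitem_int: one item per condition, with
# idx % len(v["input"]) wrapping indices for shorter conditions.
ds[5]  # -> {'cond': {'input': ..., 'target': ...}}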
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
@abstractmethod
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param dict data: Dictionary containing the data
:param list idx_list: List of indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
"""
class PinaTensorDataset(PinaDataset):
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param data: Dictionary containing the data
(only torch.Tensor/LabelTensor)
:type data: dict
:param list(int) idx_list: indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
"""
return {k: v[idx_list] for k, v in data.items()}
@property
def input(self):
"""
Method to return all input points from the dataset.
:return: Dictionary containing the input points
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
Class for the PINA dataset with torch_geometric.data.Data data
"""
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of Data objects.
:param data: List of Data or Graph objects
:type data: list(Data) | list(Graph)
:return: LabelBatch object with all the graphs collated in a single
batch of disconnected graphs.
:rtype: LabelBatch
"""
batch = LabelBatch.from_data_list(data)
return batch
def _create_tensor_batch(self, data):
"""
Create a flattened batch from a stacked torch.Tensor or LabelTensor
by merging its first two dimensions.
:param data: torch.Tensor object of shape (N, ...) where N is the
number of data points.
:type data: torch.Tensor | LabelTensor
:return: reshaped torch.Tensor or LabelTensor object
:rtype: torch.Tensor | LabelTensor
"""
out = data.reshape(-1, *data.shape[2:])
return out
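
The reshape merges the first two dimensions; a quick concrete check:

import torch

data = torch.rand(4, 5, 3)               # (batch, points, features)
out = data.reshape(-1, *data.shape[2:])  # -> shape (20, 3)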
def create_batch(self, data):
"""
Create a Batch object from a list of Data objects.
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
:return: Batch object
:rtype: Batch or PinaBatch
"""
if isinstance(data[0], Data):
return self._create_graph_batch(data)
return self._create_tensor_batch(data)
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param dict data: dictionary containing the data
:param list idx_list: list of indices to retrieve
:return: dictionary containing the data at the given indices
:rtype: dict
"""
# Return the data from the current condition
# If the data is a list of Data objects, create a Batch object
# If the data is a list of torch.Tensor objects, create a torch.Tensor
return {
k: (
self._create_graph_batch([v[i] for i in idx_list])
if isinstance(v, list)
else self._create_tensor_batch(v[idx_list])
)
for k, v in data.items()
}