Documentation and docstrings for graph and data

FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions


@@ -23,16 +23,24 @@ class DummyDataloader:
def __init__(self, dataset):
"""
param dataset: The dataset object to be processed.
:notes:
- **Distributed Environment**:
- Divides the dataset across processes using the
rank and world size.
- Fetches only the portion of data corresponding to
the current process.
- **Non-Distributed Environment**:
- Fetches the entire dataset.
Prepare a dataloader object which will return the entire dataset
in a single batch. Depending on the number of GPUs, we have the
following cases:
- **Distributed Environment** (multiple GPUs):
- Divides the dataset across processes using the rank and world
size.
- Fetches only the portion of data corresponding to the current
process.
- **Non-Distributed Environment** (single GPU):
- Fetches the entire dataset.
:param dataset: The dataset object to be processed.
:type dataset: PinaDataset
.. note:: This data loader is used when the batch size is None.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
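
The docstring above describes how the dummy dataloader shards the dataset across processes. A minimal sketch of that behaviour, using only plain torch.distributed calls rather than PINA's actual implementation:

import torch

def fetch_local_portion(dataset):
    # Distributed case: each process keeps only the slice assigned to its rank.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        indices = range(rank, len(dataset), world_size)
    else:
        # Non-distributed case: the whole dataset goes into the single batch.
        indices = range(len(dataset))
    return [dataset[i] for i in indices]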
@@ -67,23 +75,50 @@ class Collator:
Class used to collate the batch
"""
def __init__(self, max_conditions_lengths, dataset=None):
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: dict containing the maximum number
of data points to consider in a single batch for each condition.
:param bool automatic_batching: Whether automatic batching is enabled.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_custom_dataloader
if max_conditions_lengths is None
else (self._collate_standard_dataloader)
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_custom_dataloader(self, batch):
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list(int) batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_standard_dataloader(self, batch):
def _collate_torch_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is enabled.
"""
@@ -112,6 +147,19 @@ class Collator:
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
`PinaTensorDataset`.
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
collated.
:type data_list: list(torch.Tensor) | list(LabelTensor)
:raises RuntimeError: If the data is not a `torch.Tensor` or a
`LabelTensor`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
@@ -119,15 +167,36 @@ class Collator:
raise RuntimeError("Data must be Tensors or LabelTensor ")
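
For plain tensors, the tensor-dataset collate above reduces to stacking equally shaped samples along a new batch dimension (the LabelTensor.stack branch is PINA-specific and not reproduced here):

import torch

samples = [torch.rand(4) for _ in range(8)]  # eight samples of shape (4,)
batch = torch.stack(samples)                 # batched tensor of shape (8, 4)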
def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
`PinaGraphDataset`.
:param data_list: List of `Data` or `Graph` to be collated.
:type data_list: list(Data) | list(Graph)
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_graph_batch(data_list)
return self.dataset.create_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
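
The graph collate delegates batching of Data objects to the dataset's create_batch method. For torch_geometric Data objects the expected effect is the usual PyG-style merge into one disconnected super-graph; an illustration with PyG itself, given as an assumption about the underlying mechanism rather than PINA's exact code:

import torch
from torch_geometric.data import Data, Batch

g1 = Data(x=torch.rand(3, 2), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.rand(3, 2), edge_index=torch.tensor([[0, 2], [2, 1]]))

# Merge the two graphs into a single batch object holding a disconnected super-graph.
batch = Batch.from_data_list([g1, g2])
print(batch.num_graphs)  # 2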
def __call__(self, batch):
"""
Call the function to collate the batch, defined in __init__.
:param batch: list of indices or list of retrieved data
:type batch: list(int) | list(dict)
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
return self.callable_function(batch)
@@ -137,6 +206,16 @@ class PinaSampler:
"""
def __new__(cls, dataset, shuffle):
"""
Create the sampler instance, depending on whether shuffling is requested
and whether the environment is distributed.
:param PinaDataset dataset: The dataset to be sampled.
:param bool shuffle: Whether to shuffle the dataset before sampling.
:return: The sampler instance.
:rtype: torch.utils.data.Sampler
"""
if (
torch.distributed.is_available()
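
The selection logic described in the docstring can be sketched with standard torch samplers (a simplified sketch, not PINA's exact branching):

import torch
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler

def make_sampler(dataset, shuffle):
    # Distributed runs get a DistributedSampler so each process samples its own shard.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return DistributedSampler(dataset, shuffle=shuffle)
    # Otherwise fall back to plain random or sequential sampling.
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)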
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
"""
Initialize the object, creating datasets based on the input problem.
:param problem: The problem defining the dataset.
:type problem: AbstractProblem
:param train_size: Fraction or number of elements in the training split.
:type train_size: float
:param test_size: Fraction or number of elements in the test split.
:type test_size: float
:param val_size: Fraction or number of elements in the validation split.
:type val_size: float
:param batch_size: Batch size used for training. If None, the entire
dataset is used per batch.
:type batch_size: int or None
:param shuffle: Whether to shuffle the dataset before splitting.
:type shuffle: bool
:param repeat: Whether to repeat the dataset indefinitely.
:type repeat: bool
:param AbstractProblem problem: The problem containing the data on which
to train/test the model.
:param float train_size: Fraction or number of elements in the training
split.
:param float test_size: Fraction or number of elements in the test
split.
:param float val_size: Fraction or number of elements in the validation
split.
:param batch_size: The batch size used for training. If `None`, the
entire dataset is used per batch.
:type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting.
:param bool repeat: Whether to repeat the dataset indefinitely.
:param automatic_batching: Whether to enable automatic batching.
:type automatic_batching: bool
:param num_workers: Number of worker threads for data loading.
:param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading)
:type num_workers: int
:param pin_memory: Whether to use pinned memory for faster data
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False)
:type pin_memory: bool
"""
super().__init__()
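
Given the parameters documented above, a typical instantiation might look as follows (keyword names taken from the docstring; problem is assumed to be an already-defined AbstractProblem and the exact signature may differ):

dm = PinaDataModule(
    problem=problem,
    train_size=0.7,
    test_size=0.2,
    val_size=0.1,
    batch_size=None,           # None: the entire dataset is used per batch
    shuffle=True,
    repeat=False,
    automatic_batching=False,  # collate from index lists instead of per-item fetches
    num_workers=0,
    pin_memory=False,
)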
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
sampler = PinaSampler(dataset, shuffle)
if self.automatic_batching:
collate = Collator(
self.find_max_conditions_lengths(split), dataset=dataset
self.find_max_conditions_lengths(split),
self.automatic_batching,
dataset=dataset,
)
else:
collate = Collator(None, dataset=dataset)
collate = Collator(
None, self.automatic_batching, dataset=dataset
)
return DataLoader(
dataset,
self.batch_size,
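
The code above plugs the Collator and the sampler into a standard torch DataLoader. A self-contained toy version of the same wiring, with a plain collate function standing in for the Collator object:

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

dataset = TensorDataset(torch.rand(100, 4))
sampler = RandomSampler(dataset)

def collate(items):
    # items is the list of (tensor,) tuples fetched by the DataLoader
    return torch.stack([x for (x,) in items])

loader = DataLoader(dataset, batch_size=32, sampler=sampler, collate_fn=collate)
first_batch = next(iter(loader))  # tensor of shape (32, 4)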
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
def train_dataloader(self):
"""
Create the training dataloader
:return: The training dataloader
:rtype: DataLoader
"""
return self._create_dataloader("train", self.train_dataset)
def test_dataloader(self):
"""
Create the testing dataloader
:return: The testing dataloader
:rtype: DataLoader
"""
return self._create_dataloader("test", self.test_dataset)
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
"""
Dummy transfer of the batch to the device. This method is called in the
training loop when the batch size is None: in that case the batch has
already been moved to the device, so it is returned unchanged.
:param list(tuple) batch: list of tuples where the first element of each
tuple is the condition name and the second element is the data.
:param device: device to which the batch is transferred
:type device: torch.device
:param dataloader_idx: index of the dataloader
:type dataloader_idx: int
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
return batch
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called automatically in
the training loop to move the batch to the target device.
:param dict batch: The batch to be transferred to the device.
:param device: The device to which the batch is transferred.
:type device: torch.device
:param int dataloader_idx: The index of the dataloader.
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
batch = [
(
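
The body of _transfer_batch_to_device is cut off by the hunk boundary above; the per-condition transfer it describes amounts to moving every tensor of every condition to the target device. A hedged sketch of that behaviour (the helper name is hypothetical and the data is assumed to be a dict of tensors per condition):

def move_batch_to_device(batch, device):
    # batch is a list of (condition_name, data_dict) pairs; move each tensor.
    return [
        (name, {key: value.to(device) for key, value in data.items()})
        for name, data in batch
    ]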
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
@property
def input(self):
"""
# TODO
Return all the input points coming from all the datasets.
:return: The input points for each available split.
:rtype: dict
"""
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return = self.test_dataset.input
to_return["test"] = self.test_dataset.input
return to_return
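
With the fix above, the input property returns one entry per available split. Hypothetical usage, assuming a data module like the one sketched earlier has been set up:

inputs = dm.input
print(inputs.keys())        # e.g. dict_keys(['train', 'val', 'test'])
train_points = inputs["train"]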