Documentation and docstring graph and data
This commit is contained in:
committed by
Nicola Demo
parent
6ce0bafc2b
commit
635e3b3a75
@@ -23,16 +23,24 @@ class DummyDataloader:
|
||||
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
param dataset: The dataset object to be processed.
|
||||
:notes:
|
||||
- **Distributed Environment**:
|
||||
- Divides the dataset across processes using the
|
||||
rank and world size.
|
||||
- Fetches only the portion of data corresponding to
|
||||
the current process.
|
||||
- **Non-Distributed Environment**:
|
||||
- Fetches the entire dataset.
|
||||
Preprare a dataloader object which will return the entire dataset
|
||||
in a single batch. Depending on the number of GPUs, the dataset we
|
||||
have the following cases:
|
||||
|
||||
- **Distributed Environment** (multiple GPUs):
|
||||
- Divides the dataset across processes using the rank and world
|
||||
size.
|
||||
- Fetches only the portion of data corresponding to the current
|
||||
process.
|
||||
- **Non-Distributed Environment** (single GPU):
|
||||
- Fetches the entire dataset.
|
||||
|
||||
:param dataset: The dataset object to be processed.
|
||||
:type dataset: PinaDataset
|
||||
|
||||
.. note:: This data loader is used when the batch size is None.
|
||||
"""
|
||||
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
@@ -67,23 +75,50 @@ class Collator:
|
||||
Class used to collate the batch
|
||||
"""
|
||||
|
||||
def __init__(self, max_conditions_lengths, dataset=None):
|
||||
def __init__(
|
||||
self, max_conditions_lengths, automatic_batching, dataset=None
|
||||
):
|
||||
"""
|
||||
Initialize the object, setting the collate function based on whether
|
||||
automatic batching is enabled or not.
|
||||
|
||||
:param dict max_conditions_lengths: dict containing the maximum number
|
||||
of data points to consider in a single batch for each condition.
|
||||
:param PinaDataset dataset: The dataset where the data is stored.
|
||||
"""
|
||||
|
||||
self.max_conditions_lengths = max_conditions_lengths
|
||||
# Set the collate function based on the batching strategy
|
||||
# collate_pina_dataloader is used when automatic batching is disabled
|
||||
# collate_torch_dataloader is used when automatic batching is enabled
|
||||
self.callable_function = (
|
||||
self._collate_custom_dataloader
|
||||
if max_conditions_lengths is None
|
||||
else (self._collate_standard_dataloader)
|
||||
self._collate_torch_dataloader
|
||||
if automatic_batching
|
||||
else (self._collate_pina_dataloader)
|
||||
)
|
||||
self.dataset = dataset
|
||||
|
||||
# Set the function which performs the actual collation
|
||||
if isinstance(self.dataset, PinaTensorDataset):
|
||||
# If the dataset is a PinaTensorDataset, use this collate function
|
||||
self._collate = self._collate_tensor_dataset
|
||||
else:
|
||||
# If the dataset is a PinaDataset, use this collate function
|
||||
self._collate = self._collate_graph_dataset
|
||||
|
||||
def _collate_custom_dataloader(self, batch):
|
||||
def _collate_pina_dataloader(self, batch):
|
||||
"""
|
||||
Function used to create a batch when automatic batching is disabled.
|
||||
|
||||
:param list(int) batch: List of integers representing the indices of
|
||||
the data points to be fetched.
|
||||
:return: Dictionary containing the data points fetched from the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
# Call the fetch_from_idx_list method of the dataset
|
||||
return self.dataset.fetch_from_idx_list(batch)
|
||||
|
||||
def _collate_standard_dataloader(self, batch):
|
||||
def _collate_torch_dataloader(self, batch):
|
||||
"""
|
||||
Function used to collate the batch
|
||||
"""
|
||||
@@ -112,6 +147,19 @@ class Collator:
|
||||
|
||||
@staticmethod
|
||||
def _collate_tensor_dataset(data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaTensorDataset`.
|
||||
|
||||
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
|
||||
collated.
|
||||
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||
:raises RuntimeError: If the data is not a `torch.Tensor` or a
|
||||
`LabelTensor`.
|
||||
:return: Batch of data
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
return LabelTensor.stack(data_list)
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
@@ -119,15 +167,36 @@ class Collator:
|
||||
raise RuntimeError("Data must be Tensors or LabelTensor ")
|
||||
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaGraphDataset`.
|
||||
|
||||
:param data_list: List of `Data` or `Graph` to be collated.
|
||||
:type data_list: list(Data) | list(Graph)
|
||||
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
|
||||
:return: Batch of data
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
return LabelTensor.cat(data_list)
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
return torch.cat(data_list)
|
||||
if isinstance(data_list[0], Data):
|
||||
return self.dataset.create_graph_batch(data_list)
|
||||
return self.dataset.create_batch(data_list)
|
||||
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
||||
|
||||
def __call__(self, batch):
|
||||
"""
|
||||
Call the function to collate the batch, defined in __init__.
|
||||
|
||||
:param batch: list of indices or list of retrieved data
|
||||
:type batch: list(int) | list(dict)
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
return self.callable_function(batch)
|
||||
|
||||
|
||||
@@ -137,6 +206,16 @@ class PinaSampler:
|
||||
"""
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
"""
|
||||
Create the sampler instance, according to shuffle and whether the
|
||||
environment is distributed or not.
|
||||
|
||||
:param PinaDataset dataset: The dataset to be sampled.
|
||||
:param bool shuffle: whether to shuffle the dataset or not before
|
||||
sampling.
|
||||
:return: The sampler instance.
|
||||
:rtype: torch.utils.data.Sampler
|
||||
"""
|
||||
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
Initialize the object, creating datasets based on the input problem.
|
||||
|
||||
:param problem: The problem defining the dataset.
|
||||
:type problem: AbstractProblem
|
||||
:param train_size: Fraction or number of elements in the training split.
|
||||
:type train_size: float
|
||||
:param test_size: Fraction or number of elements in the test split.
|
||||
:type test_size: float
|
||||
:param val_size: Fraction or number of elements in the validation split.
|
||||
:type val_size: float
|
||||
:param batch_size: Batch size used for training. If None, the entire
|
||||
dataset is used per batch.
|
||||
:type batch_size: int or None
|
||||
:param shuffle: Whether to shuffle the dataset before splitting.
|
||||
:type shuffle: bool
|
||||
:param repeat: Whether to repeat the dataset indefinitely.
|
||||
:type repeat: bool
|
||||
:param AbstractProblem problem: The problem containing the data on which
|
||||
to train/test the model.
|
||||
:param float train_size: Fraction or number of elements in the training
|
||||
split.
|
||||
:param float test_size: Fraction or number of elements in the test
|
||||
split.
|
||||
:param float val_size: Fraction or number of elements in the validation
|
||||
split.
|
||||
:param batch_size: The batch size used for training. If `None`, the
|
||||
entire dataset is used per batch.
|
||||
:type batch_size: int | None
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||
:param automatic_batching: Whether to enable automatic batching.
|
||||
:type automatic_batching: bool
|
||||
:param num_workers: Number of worker threads for data loading.
|
||||
:param int num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading)
|
||||
:type num_workers: int
|
||||
:param pin_memory: Whether to use pinned memory for faster data
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. (Default False)
|
||||
:type pin_memory: bool
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
|
||||
sampler = PinaSampler(dataset, shuffle)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(
|
||||
self.find_max_conditions_lengths(split), dataset=dataset
|
||||
self.find_max_conditions_lengths(split),
|
||||
self.automatic_batching,
|
||||
dataset=dataset,
|
||||
)
|
||||
else:
|
||||
collate = Collator(None, dataset=dataset)
|
||||
collate = Collator(
|
||||
None, self.automatic_batching, dataset=dataset
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
self.batch_size,
|
||||
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
Create the training dataloader
|
||||
|
||||
:return: The training dataloader
|
||||
:rtype: DataLoader
|
||||
"""
|
||||
return self._create_dataloader("train", self.train_dataset)
|
||||
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
Create the testing dataloader
|
||||
|
||||
:return: The testing dataloader
|
||||
:rtype: DataLoader
|
||||
"""
|
||||
return self._create_dataloader("test", self.test_dataset)
|
||||
|
||||
@staticmethod
|
||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
This method is used when the batch size is None: batch has already
|
||||
been transferred to the device.
|
||||
|
||||
:param list(tuple) batch: list of tuple where the first element of the
|
||||
tuple is the condition name and the second element is the data.
|
||||
:param device: device to which the batch is transferred
|
||||
:type device: torch.device
|
||||
:param dataloader_idx: index of the dataloader
|
||||
:type dataloader_idx: int
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
return batch
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
|
||||
:param dict batch: The batch to be transferred to the device.
|
||||
:param device: The device to which the batch is transferred.
|
||||
:type device: torch.device
|
||||
:param int dataloader_idx: The index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
batch = [
|
||||
(
|
||||
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
# TODO
|
||||
Return all the input points coming from all the datasets.
|
||||
|
||||
:return: The input points for training.
|
||||
:rtype dict
|
||||
"""
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
|
||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||
to_return["val"] = self.val_dataset.input
|
||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||
to_return = self.test_dataset.input
|
||||
to_return["test"] = self.test_dataset.input
|
||||
return to_return
|
||||
|
||||
Reference in New Issue
Block a user