Documentation and docstring graph and data

Committed by: Nicola Demo
Parent: 6ce0bafc2b
Commit: 635e3b3a75
@@ -23,16 +23,24 @@ class DummyDataloader:
    def __init__(self, dataset):
        """
        param dataset: The dataset object to be processed.
        :notes:
            - **Distributed Environment**:
                - Divides the dataset across processes using the
                  rank and world size.
                - Fetches only the portion of data corresponding to
                  the current process.
            - **Non-Distributed Environment**:
                - Fetches the entire dataset.
        Prepare a dataloader object which will return the entire dataset
        in a single batch. Depending on the number of GPUs, we have the
        following cases:

        - **Distributed Environment** (multiple GPUs):
            - Divides the dataset across processes using the rank and world
              size.
            - Fetches only the portion of data corresponding to the current
              process.
        - **Non-Distributed Environment** (single GPU):
            - Fetches the entire dataset.

        :param dataset: The dataset object to be processed.
        :type dataset: PinaDataset

        .. note:: This data loader is used when the batch size is None.
        """

        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
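
To illustrate the rank-based split described in the docstring above, here is a minimal, hypothetical sketch; the helper name and index arithmetic are assumptions, not PINA's implementation, and only `fetch_from_idx_list` comes from this diff.

import torch

def fetch_for_current_process(dataset):
    # Hypothetical helper mirroring the docstring: in a distributed run,
    # each process keeps only the indices it owns; otherwise it keeps all.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        indices = list(range(rank, len(dataset), world_size))
    else:
        indices = list(range(len(dataset)))
    return dataset.fetch_from_idx_list(indices)
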
@@ -67,23 +75,50 @@ class Collator:
    Class used to collate the batch
    """

    def __init__(self, max_conditions_lengths, dataset=None):
    def __init__(
        self, max_conditions_lengths, automatic_batching, dataset=None
    ):
        """
        Initialize the object, setting the collate function based on whether
        automatic batching is enabled or not.

        :param dict max_conditions_lengths: dict containing the maximum number
            of data points to consider in a single batch for each condition.
        :param bool automatic_batching: Whether automatic batching is enabled.
        :param PinaDataset dataset: The dataset where the data is stored.
        """

        self.max_conditions_lengths = max_conditions_lengths
        # Set the collate function based on the batching strategy
        # collate_pina_dataloader is used when automatic batching is disabled
        # collate_torch_dataloader is used when automatic batching is enabled
        self.callable_function = (
            self._collate_custom_dataloader
            if max_conditions_lengths is None
            else (self._collate_standard_dataloader)
            self._collate_torch_dataloader
            if automatic_batching
            else (self._collate_pina_dataloader)
        )
        self.dataset = dataset

        # Set the function which performs the actual collation
        if isinstance(self.dataset, PinaTensorDataset):
            # If the dataset is a PinaTensorDataset, use this collate function
            self._collate = self._collate_tensor_dataset
        else:
            # If the dataset is a PinaGraphDataset, use this collate function
            self._collate = self._collate_graph_dataset

    def _collate_custom_dataloader(self, batch):
    def _collate_pina_dataloader(self, batch):
        """
        Function used to create a batch when automatic batching is disabled.

        :param list(int) batch: List of integers representing the indices of
            the data points to be fetched.
        :return: Dictionary containing the data points fetched from the dataset.
        :rtype: dict
        """
        # Call the fetch_from_idx_list method of the dataset
        return self.dataset.fetch_from_idx_list(batch)

    def _collate_standard_dataloader(self, batch):
    def _collate_torch_dataloader(self, batch):
        """
        Function used to collate the batch
        """
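
A hypothetical use of the updated constructor; the signature matches the diff above, while the argument values and the `dataset` / `list_of_samples` names are assumptions.

# Assuming `dataset` is a PinaTensorDataset and automatic batching is enabled,
# __call__ dispatches to _collate_torch_dataloader.
collator = Collator(
    max_conditions_lengths={"data": 32},
    automatic_batching=True,
    dataset=dataset,
)
batch = collator(list_of_samples)  # dict of collated tensors, per condition
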
@@ -112,6 +147,19 @@ class Collator:

    @staticmethod
    def _collate_tensor_dataset(data_list):
        """
        Function used to collate the data when the dataset is a
        `PinaTensorDataset`.

        :param data_list: List of `torch.Tensor` or `LabelTensor` to be
            collated.
        :type data_list: list(torch.Tensor) | list(LabelTensor)
        :raises RuntimeError: If the data is not a `torch.Tensor` or a
            `LabelTensor`.
        :return: Batch of data
        :rtype: torch.Tensor | LabelTensor
        """

        if isinstance(data_list[0], LabelTensor):
            return LabelTensor.stack(data_list)
        if isinstance(data_list[0], torch.Tensor):
@@ -119,15 +167,36 @@ class Collator:
        raise RuntimeError("Data must be Tensors or LabelTensor ")

    def _collate_graph_dataset(self, data_list):
        """
        Function used to collate the data when the dataset is a
        `PinaGraphDataset`.

        :param data_list: List of `Data` or `Graph` to be collated.
        :type data_list: list(Data) | list(Graph)
        :raises RuntimeError: If the data is not a `Data` or a `Graph`.
        :return: Batch of data
        :rtype: torch.Tensor | LabelTensor | LabelBatch
        """

        if isinstance(data_list[0], LabelTensor):
            return LabelTensor.cat(data_list)
        if isinstance(data_list[0], torch.Tensor):
            return torch.cat(data_list)
        if isinstance(data_list[0], Data):
            return self.dataset.create_graph_batch(data_list)
            return self.dataset.create_batch(data_list)
        raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")

    def __call__(self, batch):
        """
        Call the function to collate the batch, defined in __init__.

        :param batch: list of indices or list of retrieved data
        :type batch: list(int) | list(dict)
        :return: Dictionary containing the data points fetched from the dataset,
            collated.
        :rtype: dict
        """

        return self.callable_function(batch)
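
For reference, the difference between the tensor-dataset and graph-dataset collation above comes down to stacking versus concatenating; this is a plain PyTorch illustration, not PINA code.

import torch

a, b = torch.rand(3, 2), torch.rand(3, 2)

# Tensor datasets add a new batch dimension ...
stacked = torch.stack([a, b])       # shape (2, 3, 2)

# ... while graph datasets concatenate along the node dimension,
# so differently sized graphs can share one batch.
concatenated = torch.cat([a, b])    # shape (6, 2)
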
@@ -137,6 +206,16 @@ class PinaSampler:
    """

    def __new__(cls, dataset, shuffle):
        """
        Create the sampler instance, according to shuffle and whether the
        environment is distributed or not.

        :param PinaDataset dataset: The dataset to be sampled.
        :param bool shuffle: Whether to shuffle the dataset before sampling.
        :return: The sampler instance.
        :rtype: torch.utils.data.Sampler
        """

        if (
            torch.distributed.is_available()
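
The selection logic the docstring describes can be summarised with this sketch; it uses plain PyTorch samplers and is not the PINA implementation, and the distributed branch only applies once a process group has been initialised.

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

def make_sampler(dataset, shuffle):
    # Distributed run: shard the dataset across processes.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return DistributedSampler(dataset, shuffle=shuffle)
    # Single-process run: plain random or sequential sampling.
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
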
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
        """
        Initialize the object, creating datasets based on the input problem.

        :param problem: The problem defining the dataset.
        :type problem: AbstractProblem
        :param train_size: Fraction or number of elements in the training split.
        :type train_size: float
        :param test_size: Fraction or number of elements in the test split.
        :type test_size: float
        :param val_size: Fraction or number of elements in the validation split.
        :type val_size: float
        :param batch_size: Batch size used for training. If None, the entire
            dataset is used per batch.
        :type batch_size: int or None
        :param shuffle: Whether to shuffle the dataset before splitting.
        :type shuffle: bool
        :param repeat: Whether to repeat the dataset indefinitely.
        :type repeat: bool
        :param AbstractProblem problem: The problem containing the data on which
            to train/test the model.
        :param float train_size: Fraction or number of elements in the training
            split.
        :param float test_size: Fraction or number of elements in the test
            split.
        :param float val_size: Fraction or number of elements in the validation
            split.
        :param batch_size: The batch size used for training. If `None`, the
            entire dataset is used per batch.
        :type batch_size: int | None
        :param bool shuffle: Whether to shuffle the dataset before splitting.
        :param bool repeat: Whether to repeat the dataset indefinitely.
        :param automatic_batching: Whether to enable automatic batching.
        :type automatic_batching: bool
        :param num_workers: Number of worker threads for data loading.
        :param int num_workers: Number of worker threads for data loading.
            Default 0 (serial loading)
        :type num_workers: int
        :param pin_memory: Whether to use pinned memory for faster data
        :param bool pin_memory: Whether to use pinned memory for faster data
            transfer to GPU. (Default False)
        :type pin_memory: bool
        """
        super().__init__()
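
A hypothetical construction call consistent with the documented parameters; the keyword names mirror the docstring, while the exact positional/keyword order accepted by PINA is an assumption.

# `problem` is assumed to be an AbstractProblem with data attached.
datamodule = PinaDataModule(
    problem,
    train_size=0.7,
    val_size=0.1,
    test_size=0.2,
    batch_size=None,           # None -> one batch with the whole split (DummyDataloader path)
    shuffle=True,
    repeat=False,
    automatic_batching=False,  # Collator receives index lists instead of samples
    num_workers=0,
    pin_memory=False,
)
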
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
        sampler = PinaSampler(dataset, shuffle)
        if self.automatic_batching:
            collate = Collator(
                self.find_max_conditions_lengths(split), dataset=dataset
                self.find_max_conditions_lengths(split),
                self.automatic_batching,
                dataset=dataset,
            )
        else:
            collate = Collator(None, dataset=dataset)
            collate = Collator(
                None, self.automatic_batching, dataset=dataset
            )
        return DataLoader(
            dataset,
            self.batch_size,
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
    def train_dataloader(self):
        """
        Create the training dataloader

        :return: The training dataloader
        :rtype: DataLoader
        """
        return self._create_dataloader("train", self.train_dataset)

    def test_dataloader(self):
        """
        Create the testing dataloader

        :return: The testing dataloader
        :rtype: DataLoader
        """
        return self._create_dataloader("test", self.test_dataset)

    @staticmethod
    def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
        """
        Transfer the batch to the device. This method is called in the
        training loop. It is used when the batch size is None: the batch has
        already been transferred to the device, so it is returned unchanged.

        :param list(tuple) batch: List of tuples where the first element of
            each tuple is the condition name and the second element is the data.
        :param device: device to which the batch is transferred
        :type device: torch.device
        :param dataloader_idx: index of the dataloader
        :type dataloader_idx: int
        :return: The batch transferred to the device.
        :rtype: list(tuple)
        """
        return batch

    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
        """
        Transfer the batch to the device. This method is called in the
        training loop to move each batch to the device.

        :param dict batch: The batch to be transferred to the device.
        :param device: The device to which the batch is transferred.
        :type device: torch.device
        :param int dataloader_idx: The index of the dataloader.
        :return: The batch transferred to the device.
        :rtype: list(tuple)
        """
        batch = [
            (
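
Given the batch layout described in the docstrings (a condition name paired with a dictionary of tensors), the transfer performed here amounts to something like the following simplified sketch; it is not the exact PINA code and assumes the batch is an iterable of (condition, data) pairs.

def move_batch_to_device(batch, device):
    # batch: iterable of (condition_name, {field_name: tensor}) pairs.
    return [
        (condition, {key: value.to(device) for key, value in data.items()})
        for condition, data in batch
    ]
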
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
    @property
    def input(self):
        """
        # TODO
        Return all the input points coming from all the datasets.

        :return: The input points for training.
        :rtype: dict
        """
        to_return = {}
        if hasattr(self, "train_dataset") and self.train_dataset is not None:
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
        if hasattr(self, "val_dataset") and self.val_dataset is not None:
            to_return["val"] = self.val_dataset.input
        if hasattr(self, "test_dataset") and self.test_dataset is not None:
            to_return = self.test_dataset.input
            to_return["test"] = self.test_dataset.input
        return to_return

@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch

class PinaDatasetFactory:
    """
    Factory class for the PINA dataset. Depending on the type inside the
    conditions it creates a different dataset object:
    - PinaTensorDataset for torch.Tensor
    - PinaGraphDataset for list of torch_geometric.data.Data objects
    Factory class for the PINA dataset.

    Depending on the type inside the conditions, it creates a different dataset
    object:

    - :class:`PinaTensorDataset` for `torch.Tensor`
    - :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
      objects
    """

    def __new__(cls, conditions_dict, **kwargs):
        """
        Instantiate the appropriate subclass of :class:`PinaDataset`.

        If a graph is present in the conditions, returns a
        :class:`PinaGraphDataset`, otherwise returns a
        :class:`PinaTensorDataset`.

        :param dict conditions_dict: Dictionary containing the conditions.
        :return: An instance of the appropriate :class:`PinaDataset` subclass.
        :rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`

        :raises ValueError: If an empty dictionary is provided.
        """

        # Check if conditions_dict is empty
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
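
A hypothetical call to the factory with tensor-only conditions; the condition layout and the extra keyword arguments are assumptions based on the signatures appearing in this diff.

import torch

conditions = {
    "data": {"input": torch.rand(10, 2), "target": torch.rand(10, 1)},
}

# No Data/Graph objects in the conditions, so a PinaTensorDataset is expected.
dataset = PinaDatasetFactory(
    conditions,
    max_conditions_lengths={"data": 10},
    automatic_batching=False,
)
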
@@ -31,9 +49,21 @@ class PinaDatasetFactory:

    @staticmethod
    def _is_graph_dataset(conditions_dict):
        """
        Check if a graph is present in the conditions.

        :param conditions_dict: Dictionary containing the conditions.
        :type conditions_dict: dict
        :return: True if a graph is present in the conditions, False otherwise
        :rtype: bool
        """

        # Iterate over the conditions dictionary
        for v in conditions_dict.values():
            # Iterate over the values of the current condition
            for cond in v.values():
                if isinstance(cond, (Data, Graph, list)):
                # Check if the current value is a graph or a list/tuple of graphs
                if isinstance(cond, (Data, Graph, list, tuple)):
                    return True
        return False
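
For instance, a conditions dictionary like the following hypothetical layout would make `_is_graph_dataset` return True and route the factory to `PinaGraphDataset`.

import torch
from torch_geometric.data import Data

graph_conditions = {
    "graphs": {
        # A list of Data objects triggers the graph branch of the check above.
        "input": [Data(x=torch.rand(4, 3)) for _ in range(5)],
        "target": torch.rand(5, 1),
    },
}
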
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
    def __init__(
        self, conditions_dict, max_conditions_lengths, automatic_batching
    ):
        """
        Initialize the :class:`PinaDataset`.

        Stores the conditions dictionary, the maximum number of data points
        to consider per condition, and the automatic batching flag.

        :param dict conditions_dict: Dictionary containing the conditions.
        :param dict max_conditions_lengths: Maximum number of data points to
            consider in a single batch for each condition.
        :param bool automatic_batching: Whether PyTorch automatic batching is
            enabled in :class:`PinaDataModule`.
        """

        # Store the conditions dictionary
        self.conditions_dict = conditions_dict
        # Store the maximum number of data points per condition
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
            self._getitem_func = self._getitem_dummy

    def _get_max_len(self):
        """"""
        """
        Returns the length of the longest condition in the dataset

        :return: Length of the longest condition in the dataset
        :rtype: int
        """

        max_len = 0
        for condition in self.conditions_dict.values():
            max_len = max(max_len, len(condition["input"]))
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
        return self._getitem_func(idx)

    def _getitem_dummy(self, idx):
        """
        Return the index itself. This is used when automatic batching is
        disabled to postpone the data retrieval to the dataloader.

        :param idx: Index
        :type idx: int
        :return: Index
        :rtype: int
        """

        # If automatic batching is disabled, return the index itself
        return idx

    def _getitem_int(self, idx):
        """
        Return the data at the given index in the dataset. This is used when
        automatic batching is enabled.

        :param int idx: Index
        :return: A dictionary containing the data at the given index
        :rtype: dict
        """

        # If automatic batching is enabled, return the data at the given index
        return {
            k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
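
The two `__getitem__` behaviours can be illustrated with a small, hypothetical tensor dataset; the construction details are assumptions based on the signatures in this diff.

import torch

conditions = {"data": {"input": torch.rand(10, 2), "target": torch.rand(10, 1)}}

lazy = PinaTensorDataset(conditions, {"data": 10}, automatic_batching=False)
eager = PinaTensorDataset(conditions, {"data": 10}, automatic_batching=True)

print(lazy[3])   # 3 -- only the index; retrieval happens later in the collator
print(eager[3])  # {"data": {"input": tensor(...), "target": tensor(...)}}
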
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):

    @abstractmethod
    def _retrive_data(self, data, idx_list):
        pass
        """
        Retrieve data from the dataset given a list of indices

        :param dict data: Dictionary containing the data
        :param list idx_list: List of indices to retrieve
        :return: Dictionary containing the data at the given indices
        :rtype: dict
        """


class PinaTensorDataset(PinaDataset):
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):

    # Override _retrive_data method for torch.Tensor data
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices

        :param data: Dictionary containing the data
            (only torch.Tensor/LabelTensor)
        :type data: dict
        :param list(int) idx_list: indices to retrieve
        :return: Dictionary containing the data at the given indices
        :rtype: dict
        """

        return {k: v[idx_list] for k, v in data.items()}

    @property
    def input(self):
        """
        Method to return input points for training.
        Method to return all input points from the dataset.

        :return: Dictionary containing the input points
        :rtype: dict
        """
        return {k: v["input"] for k, v in self.conditions_dict.items()}
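
The tensor retrieval above relies on plain integer-list indexing; a minimal standalone example of the same pattern:

import torch

data = {"input": torch.arange(20).reshape(10, 2), "target": torch.arange(10)}
idx_list = [0, 4, 7]

# Same dictionary comprehension as _retrive_data: keep only rows 0, 4 and 7.
subset = {k: v[idx_list] for k, v in data.items()}
print(subset["input"].shape)  # torch.Size([3, 2])
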
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
    Class for the PINA dataset with torch_geometric.data.Data data
    """

    def _create_graph_batch_from_list(self, data):
    def _create_graph_batch(self, data):
        """
        Create a LabelBatch object from a list of Data objects.

        :param data: List of Data or Graph objects
        :type data: list(Data) | list(Graph)
        :return: LabelBatch object with all the graphs collated into a single
            batch of disconnected graphs.
        :rtype: LabelBatch
        """
        batch = LabelBatch.from_data_list(data)
        return batch

    def _create_output_batch(self, data):
    def _create_tensor_batch(self, data):
        """
        Create a batch by reshaping the input torch.Tensor object.

        :param data: torch.Tensor object of shape (N, ...) where N is the
            number of data points.
        :type data: torch.Tensor | LabelTensor
        :return: reshaped torch.Tensor or LabelTensor object
        :rtype: torch.Tensor | LabelTensor
        """
        out = data.reshape(-1, *data.shape[2:])
        return out

    def create_graph_batch(self, data):
    def create_batch(self, data):
        """
        Create a Batch object from a list of Data objects.

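
Both helpers have simple PyTorch/PyG analogues; this sketch assumes `LabelBatch` behaves like `torch_geometric.data.Batch` and is not the PINA code itself.

import torch
from torch_geometric.data import Data, Batch

# Graph side: merge several graphs into one disconnected batch.
graphs = [
    Data(x=torch.rand(3, 2), edge_index=torch.tensor([[0, 1], [1, 2]]))
    for _ in range(4)
]
graph_batch = Batch.from_data_list(graphs)      # graph_batch.num_graphs == 4

# Tensor side: collapse the leading batch dimension, as _create_tensor_batch does.
stacked = torch.rand(4, 3, 2)                   # 4 samples, 3 nodes, 2 features
flat = stacked.reshape(-1, *stacked.shape[2:])  # shape (12, 2)
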
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
        :return: Batch object
        :rtype: Batch or PinaBatch
        """

        if isinstance(data[0], Data):
            return self._create_graph_batch_from_list(data)
        return self._create_output_batch(data)
            return self._create_graph_batch(data)
        return self._create_tensor_batch(data)

    # Override _retrive_data method for graph handling
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices

        :param dict data: dictionary containing the data
        :param list idx_list: list of indices to retrieve
        :return: dictionary containing the data at the given indices
        :rtype: dict
        """
        # Return the data from the current condition
        # If the data is a list of Data objects, create a Batch object
        # If the data is a list of torch.Tensor objects, create a torch.Tensor
        return {
            k: (
                self._create_graph_batch_from_list([v[i] for i in idx_list])
                self._create_graph_batch([v[i] for i in idx_list])
                if isinstance(v, list)
                else self._create_output_batch(v[idx_list])
                else self._create_tensor_batch(v[idx_list])
            )
            for k, v in data.items()
        }
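
A hypothetical call showing the dispatch in the renamed `create_batch`; the `dataset`, `list_of_data_objects`, and `stacked_tensor` names are assumptions.

# A list of Data objects goes through _create_graph_batch ...
batch_of_graphs = dataset.create_batch(list_of_data_objects)

# ... while a stacked tensor goes through _create_tensor_batch.
batch_of_tensors = dataset.create_batch(stacked_tensor)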