Refactor dataset, dataloader and datamodule

FilippoOlivo
2025-11-12 14:32:56 +01:00
parent f07e59b69b
commit 99e2f07cf7
4 changed files with 375 additions and 460 deletions


@@ -11,203 +11,8 @@ from torch_geometric.data import Data
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
-from .dataset import PinaDatasetFactory, PinaTensorDataset
-class DummyDataloader:
-    def __init__(self, dataset):
-        """
-        Prepare a dataloader object that returns the entire dataset in a
-        single batch. Depending on the number of GPUs, the dataset is
-        managed as follows:
-
-        - **Distributed Environment** (multiple GPUs): Divides the dataset
-          across processes using the rank and world size. Fetches only the
-          portion of data corresponding to the current process.
-        - **Non-Distributed Environment** (single GPU): Fetches the entire
-          dataset.
-
-        :param PinaDataset dataset: The dataset object to be processed.
-
-        .. note::
-            This dataloader is used when the batch size is ``None``.
-        """
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-            if len(dataset) < world_size:
-                raise RuntimeError(
-                    "Dimension of the dataset is smaller than the world"
-                    " size. Increase the size of the partition or use a"
-                    " single GPU."
-                )
-            idx, i = [], rank
-            while i < len(dataset):
-                idx.append(i)
-                i += world_size
-            self.dataset = dataset.fetch_from_idx_list(idx)
-        else:
-            self.dataset = dataset.get_all_data()
-
-    def __iter__(self):
-        return self
-
-    def __len__(self):
-        return 1
-
-    def __next__(self):
-        return self.dataset
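The distributed branch above is plain strided (round-robin) sharding: each rank keeps every world_size-th index, starting from its own rank. A standalone sketch of the same index selection with toy sizes (shard_indices is a hypothetical helper, not part of PINA):

def shard_indices(dataset_len, rank, world_size):
    # Each rank takes every world_size-th sample, starting at its own rank.
    return list(range(rank, dataset_len, world_size))

# 10 samples on 4 GPUs: rank 1 receives indices [1, 5, 9].
assert shard_indices(10, 1, 4) == [1, 5, 9]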
-class Collator:
-    """
-    This callable class is used to collate the data points fetched from the
-    dataset. The collation is performed based on the type of dataset used
-    and on the batching strategy.
-    """
-
-    def __init__(
-        self, max_conditions_lengths, automatic_batching, dataset=None
-    ):
-        """
-        Initialize the object, setting the collate function based on whether
-        automatic batching is enabled or not.
-
-        :param dict max_conditions_lengths: ``dict`` containing the maximum
-            number of data points to consider in a single batch for each
-            condition.
-        :param bool automatic_batching: Whether automatic PyTorch batching
-            is enabled or not. For more information, see the
-            :class:`~pina.data.data_module.PinaDataModule` class.
-        :param PinaDataset dataset: The dataset where the data is stored.
-        """
-        self.max_conditions_lengths = max_conditions_lengths
-        # Set the collate function based on the batching strategy:
-        # _collate_pina_dataloader is used when automatic batching is
-        # disabled, _collate_torch_dataloader when it is enabled
-        self.callable_function = (
-            self._collate_torch_dataloader
-            if automatic_batching
-            else self._collate_pina_dataloader
-        )
-        self.dataset = dataset
-        # Set the function which performs the actual collation
-        if isinstance(self.dataset, PinaTensorDataset):
-            # If the dataset is a PinaTensorDataset, stack tensors
-            self._collate = self._collate_tensor_dataset
-        else:
-            # If the dataset is a PinaGraphDataset, batch graphs
-            self._collate = self._collate_graph_dataset
-
-    def _collate_pina_dataloader(self, batch):
-        """
-        Create a batch when automatic batching is disabled.
-
-        :param list[int] batch: List of integers representing the indices of
-            the data points to be fetched.
-        :return: Dictionary containing the data points fetched from the
-            dataset.
-        :rtype: dict
-        """
-        # Delegate fetching to the dataset
-        return self.dataset.fetch_from_idx_list(batch)
-
-    def _collate_torch_dataloader(self, batch):
-        """
-        Collate the batch when automatic batching is enabled.
-
-        :param list[dict] batch: List of retrieved data points.
-        :return: Dictionary containing the data points fetched from the
-            dataset, collated.
-        :rtype: dict
-        """
-        batch_dict = {}
-        if isinstance(batch, dict):
-            return batch
-        conditions_names = batch[0].keys()
-        # Iterate over the condition names
-        for condition_name in conditions_names:
-            single_cond_dict = {}
-            condition_args = batch[0][condition_name].keys()
-            for arg in condition_args:
-                data_list = [
-                    batch[idx][condition_name][arg]
-                    for idx in range(
-                        min(
-                            len(batch),
-                            self.max_conditions_lengths[condition_name],
-                        )
-                    )
-                ]
-                single_cond_dict[arg] = self._collate(data_list)
-            batch_dict[condition_name] = single_cond_dict
-        return batch_dict
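With automatic batching enabled, each fetched sample is a nested dict (condition -> field -> tensor), and the loop above transposes a list of such dicts into one dict of collated values, truncating each condition at its max_conditions_lengths entry. A self-contained toy version of the same reshaping, without the truncation:

import torch

batch = [
    {"cond": {"input": torch.rand(2), "target": torch.rand(1)}}
    for _ in range(3)
]
collated = {
    cond: {
        key: torch.stack([sample[cond][key] for sample in batch])
        for key in batch[0][cond]
    }
    for cond in batch[0]
}
print(collated["cond"]["input"].shape)  # torch.Size([3, 2])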
-    @staticmethod
-    def _collate_tensor_dataset(data_list):
-        """
-        Collate the data when the dataset is a
-        :class:`~pina.data.dataset.PinaTensorDataset`.
-
-        :param data_list: Elements to be collated.
-        :type data_list: list[torch.Tensor] | list[LabelTensor]
-        :return: Batch of data.
-        :rtype: dict
-        :raises RuntimeError: If the data is not a :class:`torch.Tensor` or
-            a :class:`~pina.label_tensor.LabelTensor`.
-        """
-        if isinstance(data_list[0], LabelTensor):
-            return LabelTensor.stack(data_list)
-        if isinstance(data_list[0], torch.Tensor):
-            return torch.stack(data_list)
-        raise RuntimeError("Data must be Tensors or LabelTensors.")
-    def _collate_graph_dataset(self, data_list):
-        """
-        Collate the data when the dataset is a
-        :class:`~pina.data.dataset.PinaGraphDataset`.
-
-        :param data_list: Elements to be collated.
-        :type data_list: list[Data] | list[Graph]
-        :return: Batch of data.
-        :rtype: dict
-        :raises RuntimeError: If the data is not a
-            :class:`~torch_geometric.data.Data` or a
-            :class:`~pina.graph.Graph`.
-        """
-        if isinstance(data_list[0], LabelTensor):
-            return LabelTensor.cat(data_list)
-        if isinstance(data_list[0], torch.Tensor):
-            return torch.cat(data_list)
-        if isinstance(data_list[0], Data):
-            return self.dataset.create_batch(data_list)
-        raise RuntimeError(
-            "Data must be Tensors, LabelTensors, or PyG "
-            "torch_geometric.data.Data objects."
-        )
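The two paths above differ in how samples are merged: tensor datasets stack, adding a new leading batch dimension, while graph datasets concatenate node-level tensors along the existing first dimension, since graphs in one batch may have different numbers of nodes. In plain torch:

import torch

# Tensor datasets: stack adds a leading batch dimension.
samples = [torch.rand(3) for _ in range(4)]
print(torch.stack(samples).shape)  # torch.Size([4, 3])

# Graph datasets: node features of shape (n_i, 2) are concatenated,
# because n_i may vary from graph to graph.
graphs = [torch.rand(5, 2), torch.rand(7, 2)]
print(torch.cat(graphs).shape)     # torch.Size([12, 2])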
-    def __call__(self, batch):
-        """
-        Perform the collation of the data fetched from the dataset. The
-        behavior of the function is set based on the batching strategy
-        during class initialization.
-
-        :param batch: List of retrieved data or sampled indices.
-        :type batch: list[int] | list[dict]
-        :return: Dictionary containing the collated data fetched from the
-            dataset.
-        :rtype: dict
-        """
-        return self.callable_function(batch)
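The removed class resolves its strategy once at construction and then simply dispatches in __call__. A minimal self-contained toy of the same pattern (ToyCollator is illustrative, not PINA API):

class ToyCollator:
    def __init__(self, automatic_batching):
        # Pick the collate strategy once, at construction time.
        self.callable_function = (
            self._collate_samples
            if automatic_batching
            else self._collate_indices
        )

    def _collate_indices(self, batch):
        return {"indices": batch}  # batch is a list of ints

    def _collate_samples(self, batch):
        return {"samples": batch}  # batch is a list of fetched samples

    def __call__(self, batch):
        return self.callable_function(batch)

print(ToyCollator(False)([0, 1, 2]))  # {'indices': [0, 1, 2]}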
+from .dataset import PinaDatasetFactory
+from .dataloader import PinaDataLoader
class PinaSampler:
@@ -235,6 +40,19 @@ class PinaSampler:
        return sampler
+class DataloaderCollector:
+    def __init__(self, dataloader_list):
+        """
+        Initialize the object.
+
+        :param list[DataLoader] dataloader_list: The dataloaders to collect.
+        """
+        assert isinstance(dataloader_list, list)
+        assert all(
+            isinstance(dataloader, DataLoader)
+            for dataloader in dataloader_list
+        )
+        self.dataloader_list = dataloader_list
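As committed, the class only validates and stores the list; a minimal construction example with two plain torch loaders:

import torch
from torch.utils.data import DataLoader, TensorDataset

loaders = [
    DataLoader(TensorDataset(torch.rand(8, 2)), batch_size=4),
    DataLoader(TensorDataset(torch.rand(6, 2)), batch_size=2),
]
collector = DataloaderCollector(loaders)  # passes both isinstance checks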
class PinaDataModule(LightningDataModule):
"""
This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
@@ -376,23 +194,23 @@ class PinaDataModule(LightningDataModule):
        if stage == "fit" or stage is None:
            self.train_dataset = PinaDatasetFactory(
                self.data_splits["train"],
-                max_conditions_lengths=self.find_max_conditions_lengths(
-                    "train"
-                ),
+                # max_conditions_lengths=self.find_max_conditions_lengths(
+                #     "train"
+                # ),
                automatic_batching=self.automatic_batching,
            )
            if "val" in self.data_splits.keys():
                self.val_dataset = PinaDatasetFactory(
                    self.data_splits["val"],
-                    max_conditions_lengths=self.find_max_conditions_lengths(
-                        "val"
-                    ),
+                    # max_conditions_lengths=self.find_max_conditions_lengths(
+                    #     "val"
+                    # ),
                    automatic_batching=self.automatic_batching,
                )
        elif stage == "test":
            self.test_dataset = PinaDatasetFactory(
                self.data_splits["test"],
-                max_conditions_lengths=self.find_max_conditions_lengths("test"),
+                # max_conditions_lengths=self.find_max_conditions_lengths("test"),
                automatic_batching=self.automatic_batching,
            )
        else:
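For context, Lightning drives this method by stage: setup("fit") runs before training and validation, setup("test") before testing, which is why each dataset is built lazily in its own branch. A minimal sketch of the same convention with toy data (not PINA's datasets):

import torch
from lightning.pytorch import LightningDataModule

class ToyDataModule(LightningDataModule):
    def setup(self, stage=None):
        # One dataset per stage, built only when Lightning asks for it.
        if stage == "fit" or stage is None:
            self.train_dataset = torch.rand(10, 2)
        elif stage == "test":
            self.test_dataset = torch.rand(4, 2)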
@@ -502,32 +320,14 @@ class PinaDataModule(LightningDataModule):
            ),
            module="lightning.pytorch.trainer.connectors.data_connector",
        )
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(dataset)
-            if self.automatic_batching:
-                collate = Collator(
-                    self.find_max_conditions_lengths(split),
-                    self.automatic_batching,
-                    dataset=dataset,
-                )
-            else:
-                collate = Collator(
-                    None, self.automatic_batching, dataset=dataset
-                )
-            return DataLoader(
-                dataset,
-                self.batch_size,
-                collate_fn=collate,
-                sampler=sampler,
-                num_workers=self.num_workers,
-            )
-        dataloader = DummyDataloader(dataset)
-        dataloader.dataset = self._transfer_batch_to_device(
-            dataloader.dataset, self.trainer.strategy.root_device, 0
-        )
+        return PinaDataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            collate_fn=None,
+            common_batch_size=True,
+        )
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
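The net effect of this hunk: the two old code paths (a torch DataLoader with a custom Collator when batch_size is set, a DummyDataloader returning one full batch otherwise) collapse into a single PinaDataLoader entry point. A toy showing that unification with a plain torch loader (make_loader is illustrative, not the PinaDataLoader API):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size=None):
    # batch_size=None -> one full-dataset batch (old DummyDataloader
    # behavior); otherwise an ordinary mini-batch loader.
    effective = len(dataset) if batch_size is None else batch_size
    return DataLoader(dataset, batch_size=effective)

data = TensorDataset(torch.rand(10, 2))
print(len(make_loader(data)))     # 1  (single full batch)
print(len(make_loader(data, 4)))  # 3  (ceil(10 / 4))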
    def find_max_conditions_lengths(self, split):
        """