From 3ea05e845d41a181c4071e6e2a936be2f4a5d8c5 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 23 Jan 2025 12:04:31 +0100 Subject: [PATCH] Improve DataLoader performance when batch_size=None (#423) --- pina/data/data_module.py | 67 +++++++++++++++++++++++++--------------- pina/data/dataset.py | 4 +++ 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 4831e20..4b529fe 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -8,6 +8,19 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \ from torch.utils.data.distributed import DistributedSampler from .dataset import PinaDatasetFactory +class DummyDataloader: + def __init__(self, dataset, device): + self.dataset = dataset.get_all_data() + + def __iter__(self): + return self + + def __len__(self): + return 1 + + def __next__(self): + return self.dataset + class Collator: def __init__(self, max_conditions_lengths, ): self.max_conditions_lengths = max_conditions_lengths @@ -232,40 +245,41 @@ class PinaDataModule(LightningDataModule): """ Create the validation dataloader """ - - batch_size = self.batch_size if self.batch_size is not None else len( - self.val_dataset) - - # Use default batching in torch DataLoader (good is batch size is small) - if self.automatic_batching: - collate = Collator(self.find_max_conditions_lengths('val')) - return DataLoader(self.val_dataset, batch_size, - collate_fn=collate) - collate = Collator(None) # Use custom batching (good if batch size is large) - sampler = PinaBatchSampler(self.val_dataset, batch_size, shuffle=False) - return DataLoader(self.val_dataset, sampler=sampler, + if self.batch_size is not None: + # Use default batching in torch DataLoader (good is batch size is small) + if self.automatic_batching: + collate = Collator(self.find_max_conditions_lengths('val')) + return DataLoader(self.val_dataset, self.batch_size, collate_fn=collate) + collate = Collator(None) 
+            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
+            return DataLoader(self.val_dataset, sampler=sampler,
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        # Use default batching in torch DataLoader (good is batch size is small)
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.train_dataset)
-
-        if self.automatic_batching:
-            collate = Collator(self.find_max_conditions_lengths('train'))
-            return DataLoader(self.train_dataset, batch_size,
-                              collate_fn=collate)
-        collate = Collator(None) # Use custom batching (good if batch size is large)
-
-        sampler = PinaBatchSampler(self.train_dataset, batch_size,
-                                   shuffle=False)
-        return DataLoader(self.train_dataset, sampler=sampler,
+        if self.batch_size is not None:
+            # Use default batching in torch DataLoader (good if batch size is small)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths('train'))
+                return DataLoader(self.train_dataset, self.batch_size,
+                                  collate_fn=collate)
+            # Use custom batching (good if batch size is large)
+            collate = Collator(None)
+            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
+                                       shuffle=False)
+            return DataLoader(self.train_dataset, sampler=sampler,
                               collate_fn=collate)
+        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
 
     def test_dataloader(self):
         """
@@ -279,6 +293,9 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
+    def dummy_transfer_to_device(self, batch,
+                                 device, dataloader_idx):
+        return batch
+
     def transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index e5685f1..8f41c0b 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -75,6 +75,10 @@ class PinaTensorDataset(PinaDataset):
                          for k, v in data.items()}
         return to_return_dict
 
+    def get_all_data(self):
+        index = list(range(len(self)))
+        return self._getitem_list(index)
+
     def __getitem__(self, idx):
         return self._getitem_func(idx)