Improve DataLoader performance when batch_size=None (#423)

Filippo Olivo
2025-01-23 12:04:31 +01:00
committed by Nicola Demo
parent 7706ef12c3
commit 3ea05e845d
2 changed files with 46 additions and 25 deletions


@@ -8,6 +8,19 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 
+class DummyDataloader:
+    def __init__(self, dataset, device):
+        self.dataset = dataset.get_all_data()
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return 1
+
+    def __next__(self):
+        return self.dataset
+
 class Collator:
     def __init__(self, max_conditions_lengths, ):
         self.max_conditions_lengths = max_conditions_lengths
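
For orientation, a minimal sketch (not part of the diff) of how DummyDataloader behaves, assuming the class shown above is in scope and paired with any dataset exposing get_all_data(): the whole dataset is pre-collated once in __init__, __len__ reports a single batch per epoch, and __next__ hands back the same batch object on every call, so a full-batch epoch costs one attribute lookup instead of per-sample indexing and collation.

import torch

# Toy dataset satisfying the contract DummyDataloader relies on (sketch only).
class ToyDataset:
    def __init__(self, n=8):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(-1)

    def __len__(self):
        return len(self.x)

    def get_all_data(self):
        # Every sample, already collated into one batch.
        return {"x": self.x}

loader = DummyDataloader(ToyDataset(), device="cpu")
print(len(loader))            # 1 -> one "batch" per epoch
batch = next(iter(loader))    # always the same pre-built batch
print(batch["x"].shape)       # torch.Size([8, 1])
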
@@ -232,40 +245,41 @@ class PinaDataModule(LightningDataModule):
         """
         Create the validation dataloader
         """
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.val_dataset)
-        # Use default batching in torch DataLoader (good if batch size is small)
-        if self.automatic_batching:
-            collate = Collator(self.find_max_conditions_lengths('val'))
-            return DataLoader(self.val_dataset, batch_size,
-                              collate_fn=collate)
-        collate = Collator(None)
-        # Use custom batching (good if batch size is large)
-        sampler = PinaBatchSampler(self.val_dataset, batch_size, shuffle=False)
-        return DataLoader(self.val_dataset, sampler=sampler,
-                          collate_fn=collate)
+        # Use custom batching (good if batch size is large)
+        if self.batch_size is not None:
+            # Use default batching in torch DataLoader (good if batch size is small)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths('val'))
+                return DataLoader(self.val_dataset, self.batch_size,
+                                  collate_fn=collate)
+            collate = Collator(None)
+            sampler = PinaBatchSampler(self.val_dataset, self.batch_size,
+                                       shuffle=False)
+            return DataLoader(self.val_dataset, sampler=sampler,
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.train_dataset)
-        # Use default batching in torch DataLoader (good if batch size is small)
-        if self.automatic_batching:
-            collate = Collator(self.find_max_conditions_lengths('train'))
-            return DataLoader(self.train_dataset, batch_size,
-                              collate_fn=collate)
-        collate = Collator(None)
-        # Use custom batching (good if batch size is large)
-        sampler = PinaBatchSampler(self.train_dataset, batch_size,
-                                   shuffle=False)
-        return DataLoader(self.train_dataset, sampler=sampler,
-                          collate_fn=collate)
+        # Use custom batching (good if batch size is large)
+        if self.batch_size is not None:
+            # Use default batching in torch DataLoader (good if batch size is small)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths('train'))
+                return DataLoader(self.train_dataset, self.batch_size,
+                                  collate_fn=collate)
+            collate = Collator(None)
+            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
+                                       shuffle=False)
+            return DataLoader(self.train_dataset, sampler=sampler,
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
 
     def test_dataloader(self):
         """
@@ -279,6 +293,9 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
+    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+        return batch
+
     def transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
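
The added dummy_transfer_to_device is the other half of the trick: once the full-dataset batch has been moved to the device, the module's transfer_batch_to_device attribute is rebound to this identity function, so Lightning's per-step transfer call returns the already-resident batch untouched. A toy illustration of the rebinding pattern (sketch, not PINA code):

class Module:
    def transfer(self, batch, device):
        print(f"moving batch to {device}")
        return batch  # imagine batch.to(device) here

    def _noop(self, batch, device):
        return batch  # batch already lives on the device

m = Module()
batch = m.transfer({"x": [1, 2, 3]}, "cuda:0")  # copied once
m.transfer = m._noop                            # later calls are free
m.transfer(batch, "cuda:0")                     # no copy, nothing printed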


@@ -75,6 +75,10 @@ class PinaTensorDataset(PinaDataset):
                 for k, v in data.items()}
             return to_return_dict
 
+    def get_all_data(self):
+        index = [i for i in range(len(self))]
+        return self._getitem_list(index)
+
     def __getitem__(self, idx):
         return self._getitem_func(idx)
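
get_all_data simply gathers every index of the dataset in one shot, so DummyDataloader receives a single batch containing all samples. For a dict-of-tensors dataset this is roughly equivalent to the following (sketch; _getitem_list is assumed to index each stored tensor with the given list, as PinaTensorDataset does):

import torch

data = {"input": torch.rand(100, 2), "target": torch.rand(100, 1)}
index = [i for i in range(len(data["input"]))]        # every sample index
full_batch = {k: v[index] for k, v in data.items()}   # one batch, all samples
assert full_batch["input"].shape == (100, 2)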