Improve DataLoader performance when batch_size=None (#423)
This commit is contained in:
committed by
Nicola Demo
parent
7706ef12c3
commit
3ea05e845d
@@ -8,6 +8,19 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from .dataset import PinaDatasetFactory
|
||||
|
||||
class DummyDataloader:
    """Loader that serves the entire dataset as one pre-built batch.

    Used when ``batch_size is None``: the whole dataset is materialised
    once via ``dataset.get_all_data()`` and the same batch is handed back
    on every ``__next__`` call. ``__len__`` reports a single batch per
    epoch; iteration never raises ``StopIteration``, so the caller is
    expected to stop after ``len(self)`` batches.
    """

    def __init__(self, dataset, device):
        # NOTE(review): `device` is accepted but unused here; the data
        # module moves `self.dataset` to the device right after creation.
        self.dataset = dataset.get_all_data()

    def __iter__(self):
        # The loader acts as its own iterator.
        return self

    def __len__(self):
        # One full-dataset batch per epoch.
        return 1

    def __next__(self):
        # Always yield the same cached batch.
        return self.dataset
|
||||
|
||||
class Collator:
    """Collate-function factory for PINA dataloaders.

    :param max_conditions_lengths: mapping of condition name to the
        maximum number of samples to keep per batch, or ``None`` when a
        custom batch sampler handles batching.
        NOTE(review): semantics inferred from call sites
        (``Collator(self.find_max_conditions_lengths(...))`` vs
        ``Collator(None)``) — confirm against the full class.
    """

    def __init__(self, max_conditions_lengths):
        # Fix: removed the stray dangling comma from the signature
        # (`max_conditions_lengths, )`).
        self.max_conditions_lengths = max_conditions_lengths
|
||||
@@ -232,40 +245,41 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
Create the validation dataloader
|
||||
"""
|
||||
|
||||
batch_size = self.batch_size if self.batch_size is not None else len(
|
||||
self.val_dataset)
|
||||
|
||||
# Use default batching in torch DataLoader (good is batch size is small)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(self.find_max_conditions_lengths('val'))
|
||||
return DataLoader(self.val_dataset, batch_size,
|
||||
collate_fn=collate)
|
||||
collate = Collator(None)
|
||||
# Use custom batching (good if batch size is large)
|
||||
sampler = PinaBatchSampler(self.val_dataset, batch_size, shuffle=False)
|
||||
return DataLoader(self.val_dataset, sampler=sampler,
|
||||
if self.batch_size is not None:
|
||||
# Use default batching in torch DataLoader (good is batch size is small)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(self.find_max_conditions_lengths('val'))
|
||||
return DataLoader(self.val_dataset, self.batch_size,
|
||||
collate_fn=collate)
|
||||
collate = Collator(None)
|
||||
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
|
||||
return DataLoader(self.val_dataset, sampler=sampler,
|
||||
collate_fn=collate)
|
||||
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
||||
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
||||
self.transfer_batch_to_device = self.dummy_transfer_to_device
|
||||
|
||||
def train_dataloader(self):
    """
    Create the training dataloader.

    When ``self.batch_size`` is set, a regular :class:`DataLoader` is
    built, either with torch's automatic batching (good if batch size is
    small) or with a custom ``PinaBatchSampler`` (good if batch size is
    large). When ``self.batch_size`` is ``None`` the whole dataset is
    served as a single batch through ``DummyDataloader``: the batch is
    moved to the device once, and the per-batch transfer hook is then
    replaced with a no-op so the data is not re-transferred every epoch.
    """
    # Fix: the original span contained both the pre- and post-change
    # versions of this method interleaved (diff residue), leaving dead
    # duplicated logic; this is the coherent post-change version.
    if self.batch_size is not None:
        # Use default batching in torch DataLoader (good if batch size is small)
        if self.automatic_batching:
            collate = Collator(self.find_max_conditions_lengths('train'))
            return DataLoader(self.train_dataset, self.batch_size,
                              collate_fn=collate)
        # Use custom batching (good if batch size is large)
        collate = Collator(None)
        sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
                                   shuffle=False)
        return DataLoader(self.train_dataset, sampler=sampler,
                          collate_fn=collate)
    # batch_size is None: serve the full dataset as one batch, moved to
    # the device once; subsequent transfer calls become no-ops.
    dataloader = DummyDataloader(self.train_dataset,
                                 self.trainer.strategy.root_device)
    dataloader.dataset = self.transfer_batch_to_device(
        dataloader.dataset, self.trainer.strategy.root_device, 0)
    self.transfer_batch_to_device = self.dummy_transfer_to_device
    return dataloader
|
||||
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
@@ -279,6 +293,9 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
raise NotImplementedError("Predict dataloader not implemented")
|
||||
|
||||
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
    """No-op replacement for ``transfer_batch_to_device``.

    Installed once the single full batch has already been moved to the
    device, so the batch is returned untouched on later calls.
    """
    return batch
|
||||
|
||||
def transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
|
||||
@@ -75,6 +75,10 @@ class PinaTensorDataset(PinaDataset):
|
||||
for k, v in data.items()}
|
||||
return to_return_dict
|
||||
|
||||
def get_all_data(self):
    """Return every sample in the dataset as one aggregated batch."""
    # Fetch all indices in one shot through the list-based getter.
    return self._getitem_list(list(range(len(self))))
|
||||
|
||||
def __getitem__(self, idx):
    """Delegate item access to ``self._getitem_func``.

    The concrete getter is selected elsewhere in the class — presumably
    single-index vs. list-of-indices retrieval; confirm against the
    full ``PinaTensorDataset`` definition.
    """
    return self._getitem_func(idx)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user