Improve DataLoader performance when batch_size=None (#423)
commit 3ea05e845d
parent 7706ef12c3
committed by Nicola Demo
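When batch_size is None, the old code still went through a torch DataLoader with batch_size=len(dataset), re-collating the full dataset every epoch and re-transferring it to the device at every step. This commit short-circuits that path: the dataset is materialized once through a new get_all_data() method, wrapped in a DummyDataloader that yields it as a single pre-built batch, moved to the target device a single time, after which the per-step transfer_batch_to_device hook is swapped for a no-op.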
@@ -8,6 +8,19 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 
+
+class DummyDataloader:
+
+    def __init__(self, dataset, device):
+        self.dataset = dataset.get_all_data()
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return 1
+
+    def __next__(self):
+        return self.dataset
+
 
 class Collator:
 
     def __init__(self, max_conditions_lengths, ):
         self.max_conditions_lengths = max_conditions_lengths
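Note that this iterator never raises StopIteration and its device argument is unused: the epoch length comes entirely from __len__ returning 1, so the consumer pulls exactly one full-dataset batch per epoch. A minimal behavioral sketch, with a hypothetical stand-in dataset that provides only the get_all_data() contract:

    # _FakeDataset is hypothetical; only get_all_data() is required.
    class _FakeDataset:
        def get_all_data(self):
            return {"pts": [0.0, 1.0, 2.0]}  # whole dataset as one batch

    loader = DummyDataloader(_FakeDataset(), device="cpu")
    assert len(loader) == 1                  # epoch is capped at one batch
    assert next(iter(loader)) == {"pts": [0.0, 1.0, 2.0]}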
@@ -232,40 +245,41 @@ class PinaDataModule(LightningDataModule):
         """
         Create the validation dataloader
         """
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.val_dataset)
-
-        # Use default batching in torch DataLoader (good if batch size is small)
-        if self.automatic_batching:
-            collate = Collator(self.find_max_conditions_lengths('val'))
-            return DataLoader(self.val_dataset, batch_size,
-                              collate_fn=collate)
-        collate = Collator(None)
-        # Use custom batching (good if batch size is large)
-        sampler = PinaBatchSampler(self.val_dataset, batch_size, shuffle=False)
-        return DataLoader(self.val_dataset, sampler=sampler,
-                          collate_fn=collate)
+        if self.batch_size is not None:
+            # Use default batching in torch DataLoader (good if batch size is small)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths('val'))
+                return DataLoader(self.val_dataset, self.batch_size,
+                                  collate_fn=collate)
+            # Use custom batching (good if batch size is large)
+            collate = Collator(None)
+            sampler = PinaBatchSampler(self.val_dataset, self.batch_size,
+                                       shuffle=False)
+            return DataLoader(self.val_dataset, sampler=sampler,
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
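With a concrete batch size, validation keeps the two pre-existing paths: torch's automatic batching with a length-aware Collator, or a PinaBatchSampler paired with a pass-through Collator(None). Only when batch_size is None does the new DummyDataloader branch run, transferring the whole dataset to self.trainer.strategy.root_device exactly once. A dispatch summary (illustrative):

    # batch_size   automatic_batching   result
    # ----------   ------------------   ---------------------------------------------
    # int          True                 DataLoader + Collator(max_lengths)
    # int          False                DataLoader + PinaBatchSampler + Collator(None)
    # None         ignored              DummyDataloader, single device transfer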
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.train_dataset)
-
-        # Use default batching in torch DataLoader (good if batch size is small)
-        if self.automatic_batching:
-            collate = Collator(self.find_max_conditions_lengths('train'))
-            return DataLoader(self.train_dataset, batch_size,
-                              collate_fn=collate)
-        collate = Collator(None)
-        # Use custom batching (good if batch size is large)
-        sampler = PinaBatchSampler(self.train_dataset, batch_size,
-                                   shuffle=False)
-        return DataLoader(self.train_dataset, sampler=sampler,
-                          collate_fn=collate)
+        if self.batch_size is not None:
+            # Use default batching in torch DataLoader (good if batch size is small)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths('train'))
+                return DataLoader(self.train_dataset, self.batch_size,
+                                  collate_fn=collate)
+            # Use custom batching (good if batch size is large)
+            collate = Collator(None)
+            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
+                                       shuffle=False)
+            return DataLoader(self.train_dataset, sampler=sampler,
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
+        dataloader.dataset = self.transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        return dataloader
 
     def test_dataloader(self):
         """
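The training path mirrors validation. A sketch of the expected end-to-end behavior when batch_size is None (assumes a configured PinaDataModule instance dm with an attached trainer; illustrative, not taken from the test suite):

    dm.batch_size = None                       # hypothetical setup
    loader = dm.train_dataloader()
    assert isinstance(loader, DummyDataloader)
    assert len(loader) == 1                    # one full-dataset batch per epoch
    batch = next(iter(loader))                 # already on trainer.strategy.root_device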
@@ -279,6 +293,9 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
+    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+        return batch
+
     def transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
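dummy_transfer_to_device deliberately matches the signature of the transfer_batch_to_device(batch, device, dataloader_idx) hook above, so assigning it on the instance shadows the class method and turns every subsequent transfer into a no-op. The pattern in isolation (standalone sketch, hypothetical names):

    class Host:
        def hook(self, batch):
            return [x * 2 for x in batch]   # stands in for an expensive device copy

        def dummy_hook(self, batch):
            return batch                    # no-op replacement

    h = Host()
    assert h.hook([1, 2]) == [2, 4]         # class method runs
    h.hook = h.dummy_hook                   # instance attribute shadows the method
    assert h.hook([1, 2]) == [1, 2]         # no-op from now on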
@@ -75,6 +75,10 @@ class PinaTensorDataset(PinaDataset):
                           for k, v in data.items()}
         return to_return_dict
 
+    def get_all_data(self):
+        index = [i for i in range(len(self))]
+        return self._getitem_list(index)
+
     def __getitem__(self, idx):
         return self._getitem_func(idx)
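get_all_data is the hook DummyDataloader relies on: it feeds every index through the existing _getitem_list, returning the whole dataset as one collated item. A self-contained toy with the same shape (names illustrative only):

    class ToyDataset:
        def __init__(self, data):
            self._data = data

        def __len__(self):
            return len(self._data)

        def _getitem_list(self, index):
            return [self._data[i] for i in index]

        def get_all_data(self):
            # equivalent to the comprehension used in the commit
            return self._getitem_list(list(range(len(self))))

    assert ToyDataset([10, 20, 30]).get_all_data() == [10, 20, 30]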