Improvement in DDP and bug fix in DataModule (#432)

Filippo Olivo
2025-01-28 13:51:57 +01:00
committed by Nicola Demo
parent 629a6ee43b
commit 0194fab0d1
3 changed files with 98 additions and 60 deletions


@@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory
class DummyDataloader:
def __init__(self, dataset, device):
self.dataset = dataset.get_all_data()
@@ -21,16 +22,17 @@ class DummyDataloader:
def __next__(self):
return self.dataset
class Collator:
-    def __init__(self, max_conditions_lengths, ):
+    def __init__(self, max_conditions_lengths, dataset=None):
        self.max_conditions_lengths = max_conditions_lengths
        self.callable_function = self._collate_custom_dataloader if \
            max_conditions_lengths is None else (
                self._collate_standard_dataloader)
+        self.dataset = dataset

-    @staticmethod
-    def _collate_custom_dataloader(batch):
-        return batch[0]
+    def _collate_custom_dataloader(self, batch):
+        return self.dataset.fetch_from_idx_list(batch)
def _collate_standard_dataloader(self, batch):
"""
@@ -59,26 +61,24 @@ class Collator:
batch_dict[condition_name] = single_cond_dict
return batch_dict
def __call__(self, batch):
return self.callable_function(batch)
-class PinaBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle, sampler=None):
-        if sampler is None:
-            if (torch.distributed.is_available() and
-                    torch.distributed.is_initialized()):
-                rank = torch.distributed.get_rank()
-                world_size = torch.distributed.get_world_size()
-                sampler = DistributedSampler(dataset, shuffle=shuffle,
-                                             rank=rank, num_replicas=world_size)
-            else:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-        super().__init__(sampler=sampler, batch_size=batch_size,
-                         drop_last=False)

+class PinaSampler:
+    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            sampler = DistributedSampler(dataset, shuffle=shuffle)
+        else:
+            if shuffle:
+                sampler = RandomSampler(dataset)
+            else:
+                sampler = SequentialSampler(dataset)
+        return sampler
class PinaDataModule(LightningDataModule):
"""
@@ -136,6 +136,7 @@ class PinaDataModule(LightningDataModule):
else:
self.predict_dataloader = super().predict_dataloader
self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device
def setup(self, stage=None):
"""
@@ -151,7 +152,7 @@ class PinaDataModule(LightningDataModule):
self.val_dataset = PinaDatasetFactory(
self.collector_splits['val'],
max_conditions_lengths=self.find_max_conditions_lengths(
-                'val'), automatic_batching=self.automatic_batching
+                'val'), automatic_batching=self.automatic_batching
)
elif stage == 'test':
self.test_dataset = PinaDatasetFactory(
@@ -215,6 +216,7 @@ class PinaDataModule(LightningDataModule):
condition_dict[k] = v[idx]
else:
raise ValueError(f"Data type {type(v)} not supported")
# ----------- End auxiliary function ------------
logging.debug('Dataset creation in PinaDataModule obj')
@@ -247,18 +249,21 @@ class PinaDataModule(LightningDataModule):
"""
        # Use custom batching (good if batch size is large)
        if self.batch_size is not None:
            # Use default batching in torch DataLoader (good if batch size is small)
+            sampler = PinaSampler(self.val_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
            if self.automatic_batching:
                collate = Collator(self.find_max_conditions_lengths('val'))
-                return DataLoader(self.val_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.val_dataset)
+            return DataLoader(self.val_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dataloader
def train_dataloader(self):
"""
@@ -266,19 +271,21 @@ class PinaDataModule(LightningDataModule):
"""
        # Use custom batching (good if batch size is large)
        if self.batch_size is not None:
            # Use default batching in torch DataLoader (good if batch size is small)
+            sampler = PinaSampler(self.train_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
            if self.automatic_batching:
                collate = Collator(self.find_max_conditions_lengths('train'))
-                return DataLoader(self.train_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.train_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.train_dataset)
+            return DataLoader(self.train_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dataloader
def test_dataloader(self):
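The batch_size=None branch is the other half of the change: the full dataset is wrapped in a DummyDataloader, moved to the target device once, and the Lightning transfer hook is then swapped for a no-op so the already-resident batch is not moved again on every step. A small Lightning-free sketch of that pattern (OneShotLoader is a stand-in for DummyDataloader):

import torch

class OneShotLoader:
    # Yields the same pre-built batch forever, like DummyDataloader above.
    def __init__(self, batch):
        self.dataset = batch

    def __iter__(self):
        return self

    def __next__(self):
        return self.dataset

def no_op_transfer(batch, device, dataloader_idx):
    # What the swapped-in _transfer_batch_to_device_dummy amounts to.
    return batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = OneShotLoader({"data": torch.randn(8, 3)})
loader.dataset = {k: v.to(device) for k, v in loader.dataset.items()}  # pay the cost once
print(no_op_transfer(next(iter(loader)), device, 0)["data"].device)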
@@ -293,10 +300,10 @@ class PinaDataModule(LightningDataModule):
"""
raise NotImplementedError("Predict dataloader not implemented")
-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
        return batch

-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
@@ -307,4 +314,5 @@ class PinaDataModule(LightningDataModule):
dataloader_idx))
for k, v in batch.items()
]
+        return batch
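As reconstructed above, _transfer_batch_to_device rebuilds the batch as a list of (condition_name, data) pairs moved to the target device and returns it; returning the rebuilt batch matters because the dataloader methods assign its result directly (dataloader.dataset = self._transfer_batch_to_device(...)). A hedged, Lightning-free sketch of that shape:

import torch

def transfer_batch_to_device_sketch(batch, device, dataloader_idx):
    # Move each condition's data to the device and hand the result back;
    # a missing return here would leave the caller with None.
    batch = [
        (k, v.to(device) if torch.is_tensor(v) else v)
        for k, v in batch.items()
    ]
    return batch

moved = transfer_batch_to_device_sketch(
    {"cond": torch.zeros(4)}, torch.device("cpu"), 0)
print(moved)  # [('cond', tensor([0., 0., 0., 0.]))]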