Improvement in DDP and bug fix in DataModule (#432)

commit 0194fab0d1
parent 629a6ee43b
Author:    Filippo Olivo
Committed: Nicola Demo
Date:      2025-01-28 13:51:57 +01:00
3 changed files with 98 additions and 60 deletions

Changed file 1 of 3: data module (DummyDataloader, Collator, PinaSampler, PinaDataModule)

@@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory

 class DummyDataloader:
     def __init__(self, dataset, device):
         self.dataset = dataset.get_all_data()
@@ -21,16 +22,17 @@ class DummyDataloader:
     def __next__(self):
         return self.dataset

 class Collator:
-    def __init__(self, max_conditions_lengths, ):
+    def __init__(self, max_conditions_lengths, dataset=None):
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
+        self.dataset = dataset

-    @staticmethod
-    def _collate_custom_dataloader(batch):
-        return batch[0]
+    def _collate_custom_dataloader(self, batch):
+        return self.dataset.fetch_from_idx_list(batch)

     def _collate_standard_dataloader(self, batch):
         """
@@ -59,26 +61,24 @@ class Collator:
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

     def __call__(self, batch):
         return self.callable_function(batch)

-class PinaBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle, sampler=None):
-        if sampler is None:
-            if (torch.distributed.is_available() and
-                    torch.distributed.is_initialized()):
-                rank = torch.distributed.get_rank()
-                world_size = torch.distributed.get_world_size()
-                sampler = DistributedSampler(dataset, shuffle=shuffle,
-                                             rank=rank, num_replicas=world_size)
-            else:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-        super().__init__(sampler=sampler, batch_size=batch_size,
-                         drop_last=False)
+class PinaSampler:
+    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            sampler = DistributedSampler(dataset, shuffle=shuffle)
+        else:
+            if shuffle:
+                sampler = RandomSampler(dataset)
+            else:
+                sampler = SequentialSampler(dataset)
+        return sampler

 class PinaDataModule(LightningDataModule):
     """
@@ -136,6 +136,7 @@ class PinaDataModule(LightningDataModule):
         else:
             self.predict_dataloader = super().predict_dataloader
         self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device

     def setup(self, stage=None):
         """
@@ -151,7 +152,7 @@ class PinaDataModule(LightningDataModule):
             self.val_dataset = PinaDatasetFactory(
                 self.collector_splits['val'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
                     'val'), automatic_batching=self.automatic_batching
             )
         elif stage == 'test':
             self.test_dataset = PinaDatasetFactory(
@@ -215,6 +216,7 @@ class PinaDataModule(LightningDataModule):
                 condition_dict[k] = v[idx]
             else:
                 raise ValueError(f"Data type {type(v)} not supported")

         # ----------- End auxiliary function ------------

         logging.debug('Dataset creation in PinaDataModule obj')
@@ -247,18 +249,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.val_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
-                return DataLoader(self.val_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.val_dataset)
+            return DataLoader(self.val_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader

     def train_dataloader(self):
         """
@@ -266,19 +271,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.train_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('train'))
-                return DataLoader(self.train_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.train_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.train_dataset)
+            return DataLoader(self.train_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
         return dataloader

     def test_dataloader(self):
         """
@@ -293,10 +300,10 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")

-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
         return batch

-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
         training loop and is used to transfer the batch to the device.
@@ -307,4 +314,5 @@ class PinaDataModule(LightningDataModule):
                               dataloader_idx))
             for k, v in batch.items()
         ]
         return batch
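Note (not part of the commit): a minimal sketch of how the reworked pieces above are meant to cooperate. The helper name build_dataloader is invented, PinaSampler and Collator are assumed to be importable from the data module shown in this diff, and the wiring simply mirrors train_dataloader/val_dataloader.

from torch.utils.data import DataLoader
# PinaSampler and Collator as defined above (import path assumed).

def build_dataloader(dataset, batch_size, shuffle, automatic_batching,
                     max_conditions_lengths):
    # Distributed runs get a DistributedSampler; otherwise a Random/Sequential
    # sampler is chosen depending on `shuffle`.
    sampler = PinaSampler(dataset, batch_size, shuffle, automatic_batching)
    if automatic_batching:
        # Default path: the DataLoader collates per-sample items returned by
        # dataset.__getitem__.
        collate = Collator(max_conditions_lengths)
    else:
        # Custom path: __getitem__ echoes indices, so the Collator receives a
        # list of indices and assembles the batch via dataset.fetch_from_idx_list.
        collate = Collator(None, dataset)
    return DataLoader(dataset, batch_size, collate_fn=collate, sampler=sampler)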

Changed file 2 of 3: dataset module (PinaDatasetFactory, PinaDataset, PinaTensorDataset, PinaGraphDataset)

@@ -6,6 +6,7 @@ from torch.utils.data import Dataset
 from abc import abstractmethod
 from torch_geometric.data import Batch

 class PinaDatasetFactory:
     """
     Factory class for the PINA dataset. Depending on the type inside the
@@ -13,6 +14,7 @@ class PinaDatasetFactory:
     - PinaTensorDataset for torch.Tensor
     - PinaGraphDataset for list of torch_geometric.data.Data objects
     """

     def __new__(cls, conditions_dict, **kwargs):
         if len(conditions_dict) == 0:
             raise ValueError('No conditions provided')
@@ -25,10 +27,12 @@ class PinaDatasetFactory:
             raise ValueError('Conditions must be either torch.Tensor or list of Data '
                              'objects.')

 class PinaDataset(Dataset):
     """
     Abstract class for the PINA dataset
     """

     def __init__(self, conditions_dict, max_conditions_lengths):
         self.conditions_dict = conditions_dict
         self.max_conditions_lengths = max_conditions_lengths
@@ -49,6 +53,7 @@ class PinaDataset(Dataset):
     def __getitem__(self, item):
         pass

 class PinaTensorDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
@@ -64,45 +69,68 @@ class PinaTensorDataset(PinaDataset):
                 in v.keys()} for k, v in self.conditions_dict.items()
         }

-    def _getitem_list(self, idx):
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: v[cond_idx]
                                          for k, v in data.items()}
         return to_return_dict

+    @staticmethod
+    def _getitem_list(idx):
+        return idx
+
     def get_all_data(self):
         index = [i for i in range(len(self))]
-        return self._getitem_list(index)
+        return self.fetch_from_idx_list(index)

     def __getitem__(self, idx):
         return self._getitem_func(idx)

 class PinaGraphDataset(PinaDataset):
     pass
-    """
-    def __init__(self, conditions_dict, max_conditions_lengths):
+    '''
+    def __init__(self, conditions_dict, max_conditions_lengths,
+                 automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        if automatic_batching:
+            self._getitem_func = self._getitem_int
+        else:
+            self._getitem_func = self._getitem_list

-    def __getitem__(self, idx):
-        Getitem method for large batch size
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: Batch.from_data_list([v[i]
                                                                   for i in cond_idx])
                                          if isinstance(v, list)
-                                         else v[cond_idx].tensor.reshape(-1, v.size(-1))
+                                         else v[cond_idx]
                                          for k, v in data.items()
                                          }
         return to_return_dict
-    """
+
+    def _getitem_list(self, idx):
+        return idx
+
+    def _getitem_int(self, idx):
+        return {
+            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
+                in v.keys()} for k, v in self.conditions_dict.items()
+        }
+
+    def get_all_data(self):
+        index = [i for i in range(len(self))]
+        return self.fetch_from_idx_list(index)
+
+    def __getitem__(self, idx):
+        return self._getitem_func(idx)
+    '''
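Note (illustrative only): a rough usage sketch of the tensor dataset's indexing contract. The condition and field names are invented, and it is assumed that PinaTensorDataset derives self.length and self.conditions_length internally from the conditions it receives (both are used by fetch_from_idx_list but set outside this hunk).

import torch
# PinaTensorDataset as defined above (import path assumed).

conditions = {
    'domain': {'input_points': torch.rand(10, 2), 'output_points': torch.rand(10, 1)},
    'bc':     {'input_points': torch.rand(4, 2),  'output_points': torch.rand(4, 1)},
}
dataset = PinaTensorDataset(conditions,
                            max_conditions_lengths={'domain': 10, 'bc': 4},
                            automatic_batching=False)

# With automatic batching disabled, dataset[i] just returns i, so a DataLoader
# hands the Collator a list of indices and the whole batch is built in one call:
batch = dataset.fetch_from_idx_list(list(range(5)))
# Expected layout: batch['domain']['input_points'] has 5 rows, while batch['bc']
# is truncated to its maximum length (4) and indices wrap modulo the condition
# length whenever a condition is shorter than the dataset.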

Changed file 3 of 3: Trainer

@@ -72,12 +72,14 @@ class Trainer(lightning.pytorch.Trainer):
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
+        automatic_batching = False
         self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                           train_size=self.train_size,
                                           test_size=self.test_size,
                                           val_size=self.val_size,
                                           predict_size=self.predict_size,
-                                          batch_size=self.batch_size,)
+                                          batch_size=self.batch_size,
+                                          automatic_batching=automatic_batching)

     def train(self, **kwargs):
         """