From 0194fab0d12f1e1d0340d48fb47b1aa7e37b9db7 Mon Sep 17 00:00:00 2001
From: Filippo Olivo
Date: Tue, 28 Jan 2025 13:51:57 +0100
Subject: [PATCH] Improvement in DDP and bug fix in DataModule (#432)

---
 pina/data/data_module.py | 94 ++++++++++++++++++++++------------------
 pina/data/dataset.py     | 60 ++++++++++++++++++-------
 pina/trainer.py          |  4 +-
 3 files changed, 98 insertions(+), 60 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 4b529fe..56473f8 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 
+
 class DummyDataloader:
     def __init__(self, dataset, device):
         self.dataset = dataset.get_all_data()
@@ -21,16 +22,17 @@ class DummyDataloader:
     def __next__(self):
         return self.dataset
 
+
 class Collator:
-    def __init__(self, max_conditions_lengths, ):
+    def __init__(self, max_conditions_lengths, dataset=None):
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
+        self.dataset = dataset
 
-    @staticmethod
-    def _collate_custom_dataloader(batch):
-        return batch[0]
+    def _collate_custom_dataloader(self, batch):
+        return self.dataset.fetch_from_idx_list(batch)
 
     def _collate_standard_dataloader(self, batch):
         """
@@ -59,26 +61,24 @@ class Collator:
                 batch_dict[condition_name] = single_cond_dict
         return batch_dict
 
+
     def __call__(self, batch):
         return self.callable_function(batch)
 
 
-class PinaBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle, sampler=None):
-        if sampler is None:
-            if (torch.distributed.is_available() and
-                    torch.distributed.is_initialized()):
-                rank = torch.distributed.get_rank()
-                world_size = torch.distributed.get_world_size()
-                sampler = DistributedSampler(dataset, shuffle=shuffle,
-                                             rank=rank, num_replicas=world_size)
+class PinaSampler:
+    def __new__(cls, dataset, batch_size, shuffle, automatic_batching):
+
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            sampler = DistributedSampler(dataset, shuffle=shuffle)
+        else:
+            if shuffle:
+                sampler = RandomSampler(dataset)
             else:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-        super().__init__(sampler=sampler, batch_size=batch_size,
-                         drop_last=False)
+                sampler = SequentialSampler(dataset)
+        return sampler
+
 
 class PinaDataModule(LightningDataModule):
     """
@@ -136,6 +136,7 @@ class PinaDataModule(LightningDataModule):
         else:
             self.predict_dataloader = super().predict_dataloader
         self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device
 
     def setup(self, stage=None):
         """
@@ -151,7 +152,7 @@ class PinaDataModule(LightningDataModule):
             self.val_dataset = PinaDatasetFactory(
                 self.collector_splits['val'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                    'val'), automatic_batching=self.automatic_batching
+                'val'), automatic_batching=self.automatic_batching
             )
         elif stage == 'test':
             self.test_dataset = PinaDatasetFactory(
@@ -215,6 +216,7 @@ class PinaDataModule(LightningDataModule):
                     condition_dict[k] = v[idx]
                 else:
                     raise ValueError(f"Data type {type(v)} not supported")
+
         # ----------- End auxiliary function ------------
 
         logging.debug('Dataset creation in PinaDataModule obj')
@@ -247,18 +249,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.val_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
-                return DataLoader(self.val_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.val_dataset)
+            return DataLoader(self.val_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
 
     def train_dataloader(self):
         """
@@ -266,19 +271,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.train_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('train'))
-                return DataLoader(self.train_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.train_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+
+            else:
+                collate = Collator(None, self.train_dataset)
+            return DataLoader(self.train_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
         return dataloader
 
     def test_dataloader(self):
         """
@@ -293,10 +300,10 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
        return batch
 
-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
         training loop and is used to transfer the batch to the device.
@@ -307,4 +314,5 @@ class PinaDataModule(LightningDataModule):
                                dataloader_idx)) for k, v in batch.items()
             ]
+
         return batch
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 8f41c0b..8b5f998 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -6,6 +6,7 @@ from torch.utils.data import Dataset
 from abc import abstractmethod
 from torch_geometric.data import Batch
 
+
 class PinaDatasetFactory:
     """
     Factory class for the PINA dataset. Depending on the type inside the
@@ -13,6 +14,7 @@ class PinaDatasetFactory:
     - PinaTensorDataset for torch.Tensor
     - PinaGraphDataset for list of torch_geometric.data.Data objects
     """
+
     def __new__(cls, conditions_dict, **kwargs):
         if len(conditions_dict) == 0:
             raise ValueError('No conditions provided')
@@ -25,10 +27,12 @@ class PinaDatasetFactory:
             raise ValueError('Conditions must be either torch.Tensor or list of Data '
                              'objects.')
 
+
 class PinaDataset(Dataset):
     """
     Abstract class for the PINA dataset
     """
+
     def __init__(self, conditions_dict, max_conditions_lengths):
         self.conditions_dict = conditions_dict
         self.max_conditions_lengths = max_conditions_lengths
@@ -49,6 +53,7 @@ class PinaDataset(Dataset):
     def __getitem__(self, item):
         pass
 
+
 class PinaTensorDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
@@ -64,45 +69,68 @@ class PinaTensorDataset(PinaDataset):
                 in v.keys()} for k, v in self.conditions_dict.items()
         }
 
-    def _getitem_list(self, idx):
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: v[cond_idx]
                                          for k, v in data.items()}
         return to_return_dict
 
+    @staticmethod
+    def _getitem_list(idx):
+        return idx
+
     def get_all_data(self):
         index = [i for i in range(len(self))]
-        return self._getitem_list(index)
+        return self.fetch_from_idx_list(index)
 
     def __getitem__(self, idx):
         return self._getitem_func(idx)
 
+
 class PinaGraphDataset(PinaDataset):
     pass
-    """
-    def __init__(self, conditions_dict, max_conditions_lengths):
+'''
+    def __init__(self, conditions_dict, max_conditions_lengths,
+                 automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        if automatic_batching:
+            self._getitem_func = self._getitem_int
+        else:
+            self._getitem_func = self._getitem_list
 
-    def __getitem__(self, idx):
-
-        Getitem method for large batch size
-
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: Batch.from_data_list([v[i]
-                                         for i in cond_idx])
-                                         if isinstance(v, list)
-                                         else v[cond_idx].tensor.reshape(-1, v.size(-1))
-                                         for k, v in data.items()
-                                         }
+                                                                  for i in cond_idx])
+                                         if isinstance(v, list)
+                                         else v[cond_idx]
+                                         for k, v in data.items()
+                                         }
         return to_return_dict
-    """
+
+    def _getitem_list(self, idx):
+        return idx
+
+    def _getitem_int(self, idx):
+        return {
+            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
+                in v.keys()} for k, v in self.conditions_dict.items()
+        }
+
+    def get_all_data(self):
+        index = [i for i in range(len(self))]
+        return self.fetch_from_idx_list(index)
+
+    def __getitem__(self, idx):
+        return self._getitem_func(idx)
+'''
\ No newline at end of file
diff --git a/pina/trainer.py b/pina/trainer.py
index f8bccd8..6a16248 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -72,12 +72,14 @@ class Trainer(lightning.pytorch.Trainer):
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
+        automatic_batching = False
         self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                           train_size=self.train_size,
                                           test_size=self.test_size,
                                           val_size=self.val_size,
                                           predict_size=self.predict_size,
-                                          batch_size=self.batch_size,)
+                                          batch_size=self.batch_size,
+                                          automatic_batching=automatic_batching)
 
     def train(self, **kwargs):
         """
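
Note (not part of the patch): the central design change above is that batch assembly moves from a custom BatchSampler into the dataset itself. With automatic batching disabled, each __getitem__ call returns a raw integer index, the DataLoader groups those indices using a plain sampler (a DistributedSampler when a torch.distributed process group is initialized, otherwise a RandomSampler or SequentialSampler, which is exactly what PinaSampler selects), and Collator._collate_custom_dataloader forwards the whole index list to the dataset's fetch_from_idx_list, which gathers the batch in a single vectorized indexing operation. Below is a minimal runnable sketch of that pattern; ToyDataset and ToyCollator are illustrative stand-ins, not symbols from the PINA codebase.

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class ToyDataset(Dataset):
    """Stand-in for PinaTensorDataset: the dataset assembles the batch."""

    def __init__(self, n=10):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # As in PinaTensorDataset._getitem_list: return the raw index, so the
        # collate function receives a list of indices instead of items.
        return idx

    def fetch_from_idx_list(self, idx):
        # Gather the whole batch with one vectorized indexing operation.
        return self.data[idx]


class ToyCollator:
    """Mirrors Collator._collate_custom_dataloader."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __call__(self, batch):
        # batch is a list of integer indices produced by the sampler.
        return self.dataset.fetch_from_idx_list(batch)


dataset = ToyDataset()
# Under an initialized torch.distributed process group, PinaSampler would
# return a DistributedSampler here so each rank draws a disjoint shard.
sampler = RandomSampler(dataset)
loader = DataLoader(dataset, batch_size=4, sampler=sampler,
                    collate_fn=ToyCollator(dataset))
for batch in loader:
    print(batch)  # a 1-D tensor of up to 4 values, fetched in one shot

Compared to the removed PinaBatchSampler, returning a plain sampler leaves rank and world-size handling to DistributedSampler's defaults and batch grouping to the DataLoader, which is presumably what simplifies the DDP path in this change.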