Improvement in DDP and bug fix in DataModule (#432)
committed by Nicola Demo
parent 629a6ee43b
commit 0194fab0d1
```diff
@@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
+

 class DummyDataloader:
     def __init__(self, dataset, device):
         self.dataset = dataset.get_all_data()
```
```diff
@@ -21,16 +22,17 @@ class DummyDataloader:
     def __next__(self):
         return self.dataset


 class Collator:
-    def __init__(self, max_conditions_lengths, ):
+    def __init__(self, max_conditions_lengths, dataset=None):
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
+        self.dataset = dataset

-    @staticmethod
-    def _collate_custom_dataloader(batch):
-        return batch[0]
+    def _collate_custom_dataloader(self, batch):
+        return self.dataset.fetch_from_idx_list(batch)

     def _collate_standard_dataloader(self, batch):
         """
```
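
A note on what the new `Collator` wiring buys: when automatic batching is disabled, each dataset `__getitem__` returns its index unchanged (see the `_getitem_list` additions in the dataset diff below), so `collate_fn` receives a plain list of indices and `fetch_from_idx_list` can pull the whole batch with one advanced-indexing call instead of N scalar lookups. A minimal sketch with illustrative names (`ToyDataset` is ours, not part of PINA):

```python
import torch

class ToyDataset:
    """Stand-in for PinaTensorDataset: items are indices, batches are
    fetched in one vectorized call by the collate step."""

    def __init__(self, tensor):
        self.tensor = tensor

    def __getitem__(self, idx):
        return idx  # identity: defer the real work to collate

    def fetch_from_idx_list(self, idx):
        return self.tensor[idx]  # single advanced-indexing call

dataset = ToyDataset(torch.arange(12.).reshape(6, 2))
indices = [0, 3, 5]  # what the DataLoader hands to collate_fn
print(dataset.fetch_from_idx_list(indices))  # rows 0, 3 and 5 at once
```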
```diff
@@ -59,26 +61,24 @@ class Collator:
             batch_dict[condition_name] = single_cond_dict
         return batch_dict


     def __call__(self, batch):
         return self.callable_function(batch)


-class PinaBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle, sampler=None):
-        if sampler is None:
-            if (torch.distributed.is_available() and
-                    torch.distributed.is_initialized()):
-                rank = torch.distributed.get_rank()
-                world_size = torch.distributed.get_world_size()
-                sampler = DistributedSampler(dataset, shuffle=shuffle,
-                                             rank=rank, num_replicas=world_size)
-            else:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-        super().__init__(sampler=sampler, batch_size=batch_size,
-                         drop_last=False)
+class PinaSampler:
+    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            sampler = DistributedSampler(dataset, shuffle=shuffle)
+        else:
+            if shuffle:
+                sampler = RandomSampler(dataset)
+            else:
+                sampler = SequentialSampler(dataset)
+        return sampler


 class PinaDataModule(LightningDataModule):
     """
```
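
For reference, the sampler selection above relies on `DistributedSampler` inferring `rank` and `num_replicas` from the initialized default process group, which is why the explicit `get_rank()`/`get_world_size()` calls could be dropped. A minimal sketch of the same dispatch outside PINA (function name is ours):

```python
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler

def make_sampler(dataset, shuffle):
    # Same decision tree as PinaSampler.__new__: shard under DDP,
    # otherwise fall back to plain random/sequential sampling.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return DistributedSampler(dataset, shuffle=shuffle)
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)

dataset = TensorDataset(torch.arange(10.))
loader = DataLoader(dataset, batch_size=4,
                    sampler=make_sampler(dataset, shuffle=True))
```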
```diff
@@ -136,6 +136,7 @@ class PinaDataModule(LightningDataModule):
         else:
             self.predict_dataloader = super().predict_dataloader
         self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device

     def setup(self, stage=None):
         """
```
```diff
@@ -215,6 +216,7 @@ class PinaDataModule(LightningDataModule):
                 condition_dict[k] = v[idx]
             else:
                 raise ValueError(f"Data type {type(v)} not supported")
+
         # ----------- End auxiliary function ------------

         logging.debug('Dataset creation in PinaDataModule obj')
```
```diff
@@ -247,18 +249,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.val_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
-                return DataLoader(self.val_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.val_dataset)
+            return DataLoader(self.val_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader

     def train_dataloader(self):
         """
```
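
Besides fixing the dataset mix-up (the old full-batch path built the validation `DummyDataloader` from `self.train_dataset`), the val loader now gets a `PinaSampler`, so under DDP each rank sees a disjoint shard of the validation set. A quick single-process sketch of the sharding (rank and world size are passed explicitly here; inside an initialized process group they are inferred):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank,
                                 shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 2, 4, 6]; rank 1 -> [1, 3, 5, 7]
```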
```diff
@@ -266,19 +271,21 @@ class PinaDataModule(LightningDataModule):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.train_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('train'))
-                return DataLoader(self.train_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.train_dataset, sampler=sampler,
-                              collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+            else:
+                collate = Collator(None, self.train_dataset)
+            return DataLoader(self.train_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
         return dataloader

     def test_dataloader(self):
```
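
The `batch_size is None` branch in both loaders uses the same full-batch trick: materialize the entire dataset once with `get_all_data()` and hand back the same pre-built batch on every step. A toy version of the idea (epoch and termination handling in the real `DummyDataloader` is assumed, not shown):

```python
import torch

class ToyFullBatchLoader:
    """Sketch of the DummyDataloader pattern: one pre-built batch,
    returned as-is on every __next__."""

    def __init__(self, batch):
        self.dataset = batch  # already the whole dataset

    def __iter__(self):
        return self

    def __next__(self):
        return self.dataset

loader = ToyFullBatchLoader({'input_points': torch.rand(5, 2)})
batch = next(iter(loader))  # the same full-dataset dict every time
```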
```diff
@@ -293,10 +300,10 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")

-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
         return batch

-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
         training loop and is used to transfer the batch to the device.
@@ -307,4 +314,5 @@ class PinaDataModule(LightningDataModule):
                                            dataloader_idx))
             for k, v in batch.items()
         ]
+
         return batch
```
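
The renamed hooks support a rebinding trick: Lightning calls `transfer_batch_to_device` on every step, so once the full batch has been moved to the device at dataloader construction, the instance attribute is swapped to the no-op variant to avoid a redundant `.to(device)` per step. A self-contained sketch of the rebinding (class and method names are ours):

```python
import torch

class DeviceMover:
    def __init__(self):
        # Start with the real hook; swap to the no-op after a one-off move.
        self.transfer_batch_to_device = self._transfer_real

    def _transfer_real(self, batch, device):
        return {k: v.to(device) for k, v in batch.items()}

    def _transfer_noop(self, batch, device):
        return batch

mover = DeviceMover()
batch = {'x': torch.zeros(2)}
batch = mover.transfer_batch_to_device(batch, 'cpu')   # real transfer
mover.transfer_batch_to_device = mover._transfer_noop  # rebind once
batch = mover.transfer_batch_to_device(batch, 'cpu')   # no-op afterwards
```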
```diff
@@ -6,6 +6,7 @@ from torch.utils.data import Dataset
 from abc import abstractmethod
 from torch_geometric.data import Batch
+

 class PinaDatasetFactory:
     """
     Factory class for the PINA dataset. Depending on the type inside the
```
```diff
@@ -13,6 +14,7 @@ class PinaDatasetFactory:
     - PinaTensorDataset for torch.Tensor
     - PinaGraphDataset for list of torch_geometric.data.Data objects
     """
+
     def __new__(cls, conditions_dict, **kwargs):
         if len(conditions_dict) == 0:
             raise ValueError('No conditions provided')
```
```diff
@@ -25,10 +27,12 @@ class PinaDatasetFactory:
         raise ValueError('Conditions must be either torch.Tensor or list of Data '
                          'objects.')


 class PinaDataset(Dataset):
     """
     Abstract class for the PINA dataset
     """
+
     def __init__(self, conditions_dict, max_conditions_lengths):
         self.conditions_dict = conditions_dict
         self.max_conditions_lengths = max_conditions_lengths
```
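
`PinaDatasetFactory` uses `__new__` to pick a concrete dataset class from the payload type, so callers instantiate one name regardless of whether conditions hold tensors or graph data. A generic sketch of the dispatch (the backend classes here are illustrative stand-ins, not the real PINA classes):

```python
import torch

class TensorBackend:
    def __init__(self, data):
        self.data = data

class ListBackend:
    def __init__(self, data):
        self.data = data

class DatasetFactory:
    # Returning a foreign instance from __new__ skips DatasetFactory.__init__,
    # so the chosen backend is handed back fully constructed.
    def __new__(cls, data):
        if isinstance(data, torch.Tensor):
            return TensorBackend(data)
        if isinstance(data, list):
            return ListBackend(data)
        raise ValueError('Unsupported data type')

print(type(DatasetFactory(torch.zeros(3))).__name__)  # TensorBackend
```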
```diff
@@ -49,6 +53,7 @@ class PinaDataset(Dataset):
     def __getitem__(self, item):
         pass
+

 class PinaTensorDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
```
```diff
@@ -64,45 +69,68 @@ class PinaTensorDataset(PinaDataset):
             in v.keys()} for k, v in self.conditions_dict.items()
         }

-    def _getitem_list(self, idx):
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: v[cond_idx]
                                          for k, v in data.items()}
         return to_return_dict

+    @staticmethod
+    def _getitem_list(idx):
+        return idx
+
     def get_all_data(self):
         index = [i for i in range(len(self))]
-        return self._getitem_list(index)
+        return self.fetch_from_idx_list(index)

     def __getitem__(self, idx):
         return self._getitem_func(idx)


 class PinaGraphDataset(PinaDataset):
     pass
-    """
-    def __init__(self, conditions_dict, max_conditions_lengths):
+    '''
+    def __init__(self, conditions_dict, max_conditions_lengths,
+                 automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        if automatic_batching:
+            self._getitem_func = self._getitem_int
+        else:
+            self._getitem_func = self._getitem_list

-    def __getitem__(self, idx):
-        Getitem method for large batch size
+    def fetch_from_idx_list(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
-                cond_idx = [idx%condition_len for idx in cond_idx]
+                cond_idx = [idx % condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: Batch.from_data_list([v[i]
                                                                   for i in cond_idx])
                                          if isinstance(v, list)
-                                         else v[cond_idx].tensor.reshape(-1, v.size(-1))
+                                         else v[cond_idx]
                                          for k, v in data.items()
                                          }
         return to_return_dict
-    """
+
+    def _getitem_list(self, idx):
+        return idx
+
+    def _getitem_int(self, idx):
+        return {
+            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
+                in v.keys()} for k, v in self.conditions_dict.items()
+        }
+
+    def get_all_data(self):
+        index = [i for i in range(len(self))]
+        return self.fetch_from_idx_list(index)
+
+    def __getitem__(self, idx):
+        return self._getitem_func(idx)
+    '''
```
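
The `idx % condition_len` step in both `fetch_from_idx_list` implementations is what lets conditions of different sizes share one batch index stream: a condition shorter than the dataset length is sampled cyclically. Worked out on toy numbers:

```python
import torch

condition_len = 3
batch_idx = [0, 1, 2, 3, 4]  # dataset-level indices for this batch
cond_idx = [i % condition_len for i in batch_idx]
print(cond_idx)  # [0, 1, 2, 0, 1] -- the short condition wraps around

data = torch.tensor([[10.], [20.], [30.]])
print(data[cond_idx])  # rows repeat cyclically past the condition length
```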
```diff
@@ -72,12 +72,14 @@ class Trainer(lightning.pytorch.Trainer):
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
+        automatic_batching = False
         self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                           train_size=self.train_size,
                                           test_size=self.test_size,
                                           val_size=self.val_size,
                                           predict_size=self.predict_size,
-                                          batch_size=self.batch_size,)
+                                          batch_size=self.batch_size,
+                                          automatic_batching=automatic_batching)

     def train(self, **kwargs):
         """
```
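
The hard-coded `automatic_batching = False` means the Trainer currently always selects the vectorized collate path. For context, a sketch of the two `DataLoader` modes the flag distinguishes (the manual pattern follows PyTorch's documented way of loading batched data directly via a sampler that yields index lists):

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, SequentialSampler,
                              TensorDataset)

dataset = TensorDataset(torch.arange(8.))

# Automatic batching: the loader draws single indices and stacks items.
auto = DataLoader(dataset, batch_size=4)

# Manual batching: batch_size=None disables automatic batching, the sampler
# yields whole index lists, and dataset[idx_list] fetches a batch at once.
manual = DataLoader(dataset, batch_size=None,
                    sampler=BatchSampler(SequentialSampler(dataset),
                                         batch_size=4, drop_last=False))
```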