Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Committed by: Nicola Demo
Parent: 780c4921eb
Commit: 9cae9a438f
@@ -1,6 +1,5 @@
-import logging
 from lightning.pytorch import LightningDataModule
 import math
 import torch
 from ..label_tensor import LabelTensor
 from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
@@ -10,8 +9,38 @@ from .dataset import PinaDatasetFactory
 from ..collector import Collector


 class DummyDataloader:
-    def __init__(self, dataset, device):
-        self.dataset = dataset.get_all_data()
+    """
+    Dummy dataloader used when the batch size is None. It collects all the
+    data in self.dataset and returns it, when called, as a single batch.
+    """
+
+    def __init__(self, dataset):
+        """
+        :param dataset: The dataset object to be processed.
+
+        :notes:
+            - **Distributed environment**:
+                - Divides the dataset across processes using the
+                  rank and world size.
+                - Fetches only the portion of data corresponding to
+                  the current process.
+            - **Non-distributed environment**:
+                - Fetches the entire dataset.
+        """
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+            if len(dataset) < world_size:
+                raise RuntimeError(
+                    "Dimension of the dataset smaller than world size."
+                    " Increase the size of the partition or use a single GPU")
+            idx, i = [], rank
+            while i < len(dataset):
+                idx.append(i)
+                i += world_size
+            self.dataset = dataset.fetch_from_idx_list(idx)
+        else:
+            self.dataset = dataset.get_all_data()

     def __iter__(self):
         return self
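The new __init__ shards the data round-robin by process rank when
torch.distributed is initialized. Below is a minimal standalone sketch of
that strided-index scheme; rank and world_size are hypothetical stand-ins
for the values torch.distributed.get_rank() / get_world_size() would report:

# Round-robin sharding as in DummyDataloader.__init__ (standalone sketch;
# the dataset here is a plain list, not a PINA dataset).
dataset = list(range(10))
rank, world_size = 1, 4  # hypothetical: process 1 of 4

idx, i = [], rank
while i < len(dataset):
    idx.append(i)
    i += world_size

print(idx)  # [1, 5, 9] -> this process fetches only items 1, 5 and 9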
@@ -50,7 +79,7 @@ class Collator:
             for arg in condition_args:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
-                         self.max_conditions_lengths[condition_name]))]
+                        self.max_conditions_lengths[condition_name]))]
                 if isinstance(data_list[0], LabelTensor):
                     single_cond_dict[arg] = LabelTensor.stack(data_list)
                 elif isinstance(data_list[0], torch.Tensor):
@@ -61,7 +90,6 @@ class Collator:
                 batch_dict[condition_name] = single_cond_dict
         return batch_dict

-
     def __call__(self, batch):
         return self.callable_function(batch)

@@ -99,6 +127,7 @@ class PinaDataModule(LightningDataModule):
     ):
         """
         Initialize the object, creating datasets based on the input problem.
         :param problem: Problem where data are defined
         :param train_size: number/percentage of elements in train split
         :param test_size: number/percentage of elements in test split
         :param val_size: number/percentage of elements in validation split
@@ -112,6 +141,9 @@ class PinaDataModule(LightningDataModule):
         self.shuffle = shuffle
         self.repeat = repeat

+        # Check if the splits are correct
+        self._check_slit_sizes(train_size, test_size, val_size, predict_size)
+
         # Begin Data splitting
         splits_dict = {}
         if train_size > 0:
@@ -179,23 +211,28 @@ class PinaDataModule(LightningDataModule):
             len_condition = len(condition_dict['input_points'])

             lengths = [
-                int(math.floor(len_condition * length)) for length in
+                int(len_condition * length) for length in
                 splits_dict.values()
             ]

-            splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)
+            remainder = len_condition - sum(lengths)
+            for i in range(remainder):
+                lengths[i % len(lengths)] += 1
+
+            splits_dict = {k: max(1, v) for k, v in zip(splits_dict.keys(), lengths)
                            }
             to_return_dict = {}
             offset = 0

             for stage, stage_len in splits_dict.items():
                 to_return_dict[stage] = {k: v[offset:offset + stage_len]
                                          for k, v in condition_dict.items() if
                                          k != 'equation'
                                          # Equations are NEVER dataloaded
                                          }
+                if offset + stage_len > len_condition:
+                    offset = len_condition - 1
+                    continue
                 offset += stage_len
         return to_return_dict
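The truncate-and-distribute arithmetic above, combined with the max(1, v)
clamp, can assign one element more than the condition actually has, which is
what the new offset guard in the slicing loop compensates for. A worked
example with a hypothetical condition of 7 points split 0.7/0.2/0.1:

# Worked example of the new split arithmetic (sizes are hypothetical).
len_condition = 7
splits = {'train': 0.7, 'val': 0.2, 'test': 0.1}

lengths = [int(len_condition * v) for v in splits.values()]  # [4, 1, 0]
remainder = len_condition - sum(lengths)                     # 7 - 5 = 2
for i in range(remainder):
    lengths[i % len(lengths)] += 1                           # [5, 2, 0]
sizes = {k: max(1, v) for k, v in zip(splits, lengths)}
print(sizes)  # {'train': 5, 'val': 2, 'test': 1} -- sums to 8 > 7,
              # so the slicing loop must clamp its offset at the end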
@@ -234,6 +271,26 @@ class PinaDataModule(LightningDataModule):
                 dataset_dict[key].update({condition_name: data})
         return dataset_dict

+    def _create_dataloader(self, split, dataset):
+        shuffle = self.shuffle if split == 'train' else False
+        # Use custom batching (good if batch size is large)
+        if self.batch_size is not None:
+            sampler = PinaSampler(dataset, self.batch_size,
+                                  shuffle, self.automatic_batching)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths(split))
+            else:
+                collate = Collator(None, dataset)
+            return DataLoader(dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(dataset)
+        dataloader.dataset = self._transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
+
     def find_max_conditions_lengths(self, split):
         max_conditions_lengths = {}
         for k, v in self.collector_splits[split].items():
@@ -250,52 +307,19 @@ class PinaDataModule(LightningDataModule):
         """
         Create the validation dataloader
         """
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(self.val_dataset, self.batch_size,
-                                  self.shuffle, self.automatic_batching)
-            if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths('val'))
-            else:
-                collate = Collator(None, self.val_dataset)
-            return DataLoader(self.val_dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
-        dataloader = DummyDataloader(self.val_dataset,
-                                     self.trainer.strategy.root_device)
-        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
-                                                            self.trainer.strategy.root_device,
-                                                            0)
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return self._create_dataloader('val', self.val_dataset)

     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(self.train_dataset, self.batch_size,
-                                  self.shuffle, self.automatic_batching)
-            if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths('train'))
-            else:
-                collate = Collator(None, self.train_dataset)
-            return DataLoader(self.train_dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
-        dataloader = DummyDataloader(self.train_dataset,
-                                     self.trainer.strategy.root_device)
-        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
-                                                            self.trainer.strategy.root_device,
-                                                            0)
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return self._create_dataloader('train', self.train_dataset)

     def test_dataloader(self):
         """
         Create the testing dataloader
         """
-        raise NotImplementedError("Test dataloader not implemented")
+        return self._create_dataloader('test', self.test_dataset)

     def predict_dataloader(self):
         """
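val_dataloader, train_dataloader and test_dataloader now all delegate to
_create_dataloader, so the batch-size dispatch lives in one place. A
condensed, self-contained sketch of that dispatch follows; the function and
its parameters are stand-ins, not the real PINA classes:

# Stand-in sketch of the shared dispatch: mini-batches when batch_size is
# set, otherwise a single preloaded "batch" holding the whole dataset.
def create_dataloader(dataset, batch_size, device, move_to_device):
    if batch_size is not None:
        # In PINA this path builds a PinaSampler/Collator/DataLoader.
        return [dataset[i:i + batch_size]
                for i in range(0, len(dataset), batch_size)]
    # batch_size=None: move everything to the device once, up front,
    # so the per-step transfer hook can be swapped for a no-op.
    return [move_to_device(dataset, device)]

print(create_dataloader(list(range(6)), 4, 'cpu', lambda d, _: d))
# [[0, 1, 2, 3], [4, 5]]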
@@ -303,7 +327,8 @@ class PinaDataModule(LightningDataModule):
         """
         raise NotImplementedError("Predict dataloader not implemented")

-    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
+    @staticmethod
+    def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
         return batch

     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
@@ -312,10 +337,34 @@ class PinaDataModule(LightningDataModule):
         training loop and is used to transfer the batch to the device.
         """
         batch = [
-            (k, super(LightningDataModule, self).transfer_batch_to_device(v,
-                                                                          device,
-                                                                          dataloader_idx))
+            (k,
+             super(LightningDataModule, self).transfer_batch_to_device(
+                 v, device, dataloader_idx))
             for k, v in batch.items()
         ]

         return batch

+    @staticmethod
+    def _check_slit_sizes(train_size, test_size, val_size, predict_size):
+        """
+        Check if the splits are correct
+        """
+        if train_size < 0 or test_size < 0 or val_size < 0 or predict_size < 0:
+            raise ValueError("The splits must be non-negative")
+        if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
+            raise ValueError("The sum of the splits must be 1")
+
+    @property
+    def input_points(self):
+        """
+        Return the input points of the training, validation and test
+        datasets, when available.
+        """
+        to_return = {}
+        if hasattr(self, "train_dataset") and self.train_dataset is not None:
+            to_return["train"] = self.train_dataset.input_points
+        if hasattr(self, "val_dataset") and self.val_dataset is not None:
+            to_return["val"] = self.val_dataset.input_points
+        if hasattr(self, "test_dataset") and self.test_dataset is not None:
+            to_return["test"] = self.test_dataset.input_points
+        return to_return
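The new _check_slit_sizes helper rejects negative sizes and requires the
four fractions to sum to 1 within 1e-6. The tolerance matters because
typical float fractions rarely sum to exactly 1.0, as a quick check shows:

# Why _check_slit_sizes compares against a 1e-6 tolerance instead of
# testing for exact equality with 1:
sizes = (0.7, 0.2, 0.1, 0.0)          # train, test, val, predict
print(sum(sizes))                     # 0.9999999999999999
print(sum(sizes) == 1)                # False -- an exact test would reject it
print(abs(sum(sizes) - 1) > 1e-6)     # False -- the tolerant check accepts it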