Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes (usage sketched below)
* Refactor SolverInterfaces (#435)
* Update solvers + weighting
* Update PINN for 0.2
* Modify GAROM + tests
* Add more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Author:    Dario Coscia
Date:      2025-02-17 11:26:21 +01:00
Committer: Nicola Demo
Parent:    780c4921eb
Commit:    9cae9a438f

50 changed files with 2848 additions and 4187 deletions
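
The headline change for users is the first bullet: ``batch_size=None`` now also works under DDP. A minimal sketch of the intended call, assuming ``solver`` is any SolverInterface instance built elsewhere; the accelerator/devices/strategy flags are standard lightning Trainer kwargs passed through ``**kwargs``, not names introduced by this commit:

    # hedged sketch: `solver` is assumed to be a PINA SolverInterface instance
    trainer = Trainer(
        solver=solver,
        batch_size=None,      # load all samples at once, no batching
        train_size=0.7,       # split fractions must each lie in [0, 1]
        test_size=0.2,        # and must not sum to more than 1
        val_size=0.1,
        predict_size=0.0,
        accelerator='gpu',    # standard lightning kwargs pass through **kwargs
        devices=2,
        strategy='ddp',       # DDP now supported with batch_size=None
    )
    trainer.train()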


@@ -1,9 +1,10 @@
""" Trainer module. """
import sys
import torch
import lightning
from .utils import check_consistency
from .data import PinaDataModule
from .solvers import SolverInterface, PINNInterface
class Trainer(lightning.pytorch.Trainer):
@@ -14,29 +15,88 @@ class Trainer(lightning.pytorch.Trainer):
                 train_size=.7,
                 test_size=.2,
                 val_size=.1,
                 predict_size=0.,
                 compile=None,
                 automatic_batching=None,
                 **kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
:param solver: A pina:class:`SolverInterface` solver for the differential problem.
:param solver: A pina:class:`SolverInterface` solver for the
differential problem.
:type solver: SolverInterface
:param batch_size: How many samples per batch to load. If ``batch_size=None`` all
:param batch_size: How many samples per batch to load.
If ``batch_size=None`` all
samples are loaded and data are not batched, defaults to None.
:type batch_size: int | None
:param train_size: percentage of elements in the train dataset
:type train_size: float
:param test_size: percentage of elements in the test dataset
:type test_size: float
:param val_size: percentage of elements in the val dataset
:type val_size: float
:param predict_size: percentage of elements in the predict dataset
:type predict_size: float
:param compile: if True model is compiled before training,
default False. For Windows users compilation is always disabled.
:type compile: bool
:param automatic_batching: if True automatic PyTorch batching is
performed. Please avoid using automatic batching when batch_size is
large, default False.
:type automatic_batching: bool
:Keyword Arguments:
The additional keyword arguments specify the training setup
and can be choosen from the `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
"""
        # check consistency for init types
        check_consistency(solver, SolverInterface)
        check_consistency(train_size, float)
        check_consistency(test_size, float)
        check_consistency(val_size, float)
        check_consistency(predict_size, float)
        if automatic_batching is not None:
            check_consistency(automatic_batching, bool)
        if compile is not None:
            check_consistency(compile, bool)
        if train_size + test_size + val_size + predict_size > 1:
            raise ValueError('train_size, test_size, val_size and '
                             'predict_size must not sum to more than 1.')
        for size in [train_size, test_size, val_size, predict_size]:
            if size < 0 or size > 1:
                raise ValueError('splitting sizes for train, validation, '
                                 'test and prediction must be in [0, 1].')
        if batch_size is not None:
            check_consistency(batch_size, int)
        # inference mode is set to False when validating/testing PINNs,
        # otherwise gradients are not tracked and optimization_cycle fails
        if isinstance(solver, PINNInterface):
            kwargs['inference_mode'] = False
        # logging depends on the batch size: when batch_size is None,
        # log_every_n_steps should be zero
        if batch_size is None:
            kwargs['log_every_n_steps'] = 0
        else:
            kwargs.setdefault('log_every_n_steps', 50)  # lightning default
        # set default kwargs, overriding lightning defaults
        kwargs.setdefault('enable_progress_bar', True)
        kwargs.setdefault('logger', None)
        super().__init__(**kwargs)
        # checking compilation and automatic batching
        if compile is None or sys.platform == "win32":
            compile = False
        if automatic_batching is None:
            automatic_batching = False
        # set attributes
        self.compile = compile
        self.automatic_batching = automatic_batching
        self.train_size = train_size
        self.test_size = test_size
        self.val_size = val_size
@@ -47,9 +107,19 @@ class Trainer(lightning.pytorch.Trainer):
        self.data_module = None
        self._create_loader()
        # logging flags; sync_dist synchronizes logged metrics across
        # processes in distributed runs
        self.logging_kwargs = {
            'logger': bool(
                kwargs['logger'] is None or kwargs['logger'] is True),
            'sync_dist': bool(
                len(self._accelerator_connector._parallel_devices) > 1),
            'on_step': bool(kwargs['log_every_n_steps'] > 0),
            'prog_bar': bool(kwargs['enable_progress_bar']),
            'on_epoch': True
        }
    def _move_to_device(self):
        device = self._accelerator_connector._parallel_devices[0]
        # move parameters to device
        pb = self.solver.problem
        if hasattr(pb, "unknown_parameters"):
@@ -65,37 +135,34 @@ class Trainer(lightning.pytorch.Trainer):
"""
if not self.solver.problem.are_all_domains_discretised:
error_message = '\n'.join([
f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
"not sampled"}""" for key in
f"""{" " * 13} ---> Domain {key} {
"sampled" if key in self.solver.problem.discretised_domains else
"not sampled"}""" for key in
self.solver.problem.domains.keys()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
        self.data_module = PinaDataModule(
            self.solver.problem,
            train_size=self.train_size,
            test_size=self.test_size,
            val_size=self.val_size,
            predict_size=self.predict_size,
            batch_size=self.batch_size,
            automatic_batching=self.automatic_batching)
    def train(self, **kwargs):
        """
        Train the solver.
        """
        return super().fit(self.solver, datamodule=self.data_module, **kwargs)
    def test(self, **kwargs):
        """
        Test the solver.
        """
        return super().test(self.solver, datamodule=self.data_module, **kwargs)
    @property
    def solver(self):