Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests
---------
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
117 pina/trainer.py
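The headline change: DDP training now works with ``batch_size=None`` (full-batch mode), and the split sizes passed to the Trainer are validated up front. A minimal usage sketch under assumptions not shown in this diff — ``solver`` is an already constructed ``SolverInterface`` whose problem has all domains sampled:

```python
from pina.trainer import Trainer

# `solver` is assumed to exist; building it is outside this commit's scope.
trainer = Trainer(
    solver,
    batch_size=None,      # full-batch: data are loaded unbatched
    train_size=0.7,       # split sizes must each be in [0, 1]
    test_size=0.2,        # and sum to at most 1, or __init__ raises
    val_size=0.1,
    accelerator="gpu",    # remaining kwargs are forwarded verbatim to
    devices=2,            # lightning.pytorch.Trainer, so a DDP run is
    strategy="ddp",       # requested through the usual lightning flags
)
trainer.train()
```

Note that with ``batch_size=None`` the Trainer also sets ``log_every_n_steps`` to 0, since (per the comment in the diff below) logging depends on the batch size.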
@@ -1,9 +1,10 @@
 """ Trainer module. """
+import sys
 import torch
 import lightning
 from .utils import check_consistency
 from .data import PinaDataModule
-from .solvers.solver import SolverInterface
+from .solvers import SolverInterface, PINNInterface


 class Trainer(lightning.pytorch.Trainer):
@@ -14,29 +15,88 @@ class Trainer(lightning.pytorch.Trainer):
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
-                 predict_size=.0,
+                 predict_size=0.,
+                 compile=None,
+                 automatic_batching=None,
                  **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.

-        :param solver: A pina:class:`SolverInterface` solver for the differential problem.
+        :param solver: A pina:class:`SolverInterface` solver for the
+            differential problem.
         :type solver: SolverInterface
-        :param batch_size: How many samples per batch to load. If ``batch_size=None`` all
+        :param batch_size: How many samples per batch to load.
+            If ``batch_size=None`` all
             samples are loaded and data are not batched, defaults to None.
         :type batch_size: int | None
         :param train_size: percentage of elements in the train dataset
         :type train_size: float
         :param test_size: percentage of elements in the test dataset
         :type test_size: float
         :param val_size: percentage of elements in the val dataset
         :type val_size: float
         :param predict_size: percentage of elements in the predict dataset
         :type predict_size: float
+        :param compile: if True model is compiled before training,
+            default False. For Windows users compilation is always disabled.
+        :type compile: bool
+        :param automatic_batching: if True automatic PyTorch batching is
+            performed. Please avoid using automatic batching when batch_size is
+            large, default False.
+        :type automatic_batching: bool

         :Keyword Arguments:
             The additional keyword arguments specify the training setup
             and can be chosen from the `pytorch-lightning
             Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
         """
+        # check consistency for init types
+        check_consistency(solver, SolverInterface)
+        check_consistency(train_size, float)
+        check_consistency(test_size, float)
+        check_consistency(val_size, float)
+        check_consistency(predict_size, float)
+        if automatic_batching is not None:
+            check_consistency(automatic_batching, bool)
+        if compile is not None:
+            check_consistency(compile, bool)
+        if train_size + test_size + val_size + predict_size > 1:
+            raise ValueError('train_size, test_size, val_size and predict_size '
+                             'must sum up to 1.')
+        for size in [train_size, test_size, val_size, predict_size]:
+            if size < 0 or size > 1:
+                raise ValueError('splitting sizes for train, validation, test '
+                                 'and prediction must be between [0, 1].')
+        if batch_size is not None:
+            check_consistency(batch_size, int)
+
+        # inference mode set to false when validating/testing PINNs otherwise
+        # gradient is not tracked and optimization_cycle fails
+        if isinstance(solver, PINNInterface):
+            kwargs['inference_mode'] = False
+
+        # Logging depends on the batch size, when batch_size is None then
+        # log_every_n_steps should be zero
+        if batch_size is None:
+            kwargs['log_every_n_steps'] = 0
+        else:
+            kwargs.setdefault('log_every_n_steps', 50)  # default for lightning
+
+        # Setting default kwargs, overriding lightning defaults
+        kwargs.setdefault('enable_progress_bar', True)
+        kwargs.setdefault('logger', None)

         super().__init__(**kwargs)

-        # check inheritance consistency for solver and batch size
-        check_consistency(solver, SolverInterface)
-        if batch_size is not None:
-            check_consistency(batch_size, int)
+        # checking compilation and automatic batching
+        if compile is None or sys.platform == "win32":
+            compile = False
+        if automatic_batching is None:
+            automatic_batching = False

         # set attributes
+        self.compile = compile
+        self.automatic_batching = automatic_batching
         self.train_size = train_size
         self.test_size = test_size
         self.val_size = val_size
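For reference, the new split-size guard behaves like the standalone sketch below (a mirror of the checks above, not part of the PINA API). Note the guard only rejects sums strictly greater than 1, so splits summing to less than 1 still pass even though the message says they "must sum up to 1":

```python
def check_split_sizes(train_size, test_size, val_size, predict_size=0.0):
    # Mirror of the validation added in Trainer.__init__ (illustrative only)
    if train_size + test_size + val_size + predict_size > 1:
        raise ValueError('train_size, test_size, val_size and predict_size '
                         'must sum up to 1.')
    for size in (train_size, test_size, val_size, predict_size):
        if size < 0 or size > 1:
            raise ValueError('splitting sizes for train, validation, test '
                             'and prediction must be between [0, 1].')

check_split_sizes(0.7, 0.2, 0.1)    # ok: sums to exactly 1.0
check_split_sizes(0.8, 0.3, 0.1)    # raises ValueError: sums to 1.2
```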
@@ -47,9 +107,19 @@ class Trainer(lightning.pytorch.Trainer):
         self.data_module = None
         self._create_loader()

+        # logging
+        self.logging_kwargs = {
+            'logger': bool(
+                kwargs['logger'] is None or kwargs['logger'] is True),
+            'sync_dist': bool(
+                len(self._accelerator_connector._parallel_devices) > 1),
+            'on_step': bool(kwargs['log_every_n_steps'] > 0),
+            'prog_bar': bool(kwargs['enable_progress_bar']),
+            'on_epoch': True
+        }

     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]

         # move parameters to device
         pb = self.solver.problem
         if hasattr(pb, "unknown_parameters"):
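The keys of the new `logging_kwargs` dictionary (`logger`, `sync_dist`, `on_step`, `on_epoch`, `prog_bar`) match the keyword arguments of `LightningModule.log`, which suggests the solvers unpack it when logging metrics. A hedged sketch of that presumed pattern — the solver class and its loss are hypothetical, not from this commit:

```python
import torch
import lightning


class SketchSolver(lightning.pytorch.LightningModule):
    """Hypothetical solver showing how trainer.logging_kwargs could be used."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        loss = self.net(batch).pow(2).mean()  # placeholder loss
        # logger / sync_dist / on_step / on_epoch / prog_bar are exactly
        # the flags assembled by the Trainer above
        self.log("train_loss", loss, **self.trainer.logging_kwargs)
        return loss
```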
@@ -65,37 +135,34 @@ class Trainer(lightning.pytorch.Trainer):
         """
         if not self.solver.problem.are_all_domains_discretised:
             error_message = '\n'.join([
-                f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
-                "not sampled"}""" for key in
+                f"""{" " * 13} ---> Domain {key} {
+                    "sampled" if key in self.solver.problem.discretised_domains else
+                    "not sampled"}""" for key in
                 self.solver.problem.domains.keys()
             ])
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
-        automatic_batching = False
-        self.data_module = PinaDataModule(self.solver.problem,
-                                          train_size=self.train_size,
-                                          test_size=self.test_size,
-                                          val_size=self.val_size,
-                                          predict_size=self.predict_size,
-                                          batch_size=self.batch_size,
-                                          automatic_batching=automatic_batching)
+        self.data_module = PinaDataModule(
+            self.solver.problem,
+            train_size=self.train_size,
+            test_size=self.test_size,
+            val_size=self.val_size,
+            predict_size=self.predict_size,
+            batch_size=self.batch_size,
+            automatic_batching=self.automatic_batching)

     def train(self, **kwargs):
         """
         Train the solver method.
         """
-        return super().fit(self.solver,
-                           datamodule=self.data_module,
-                           **kwargs)
+        return super().fit(self.solver, datamodule=self.data_module, **kwargs)

     def test(self, **kwargs):
         """
         Test the solver method.
         """
-        return super().test(self.solver,
-                            datamodule=self.data_module,
-                            **kwargs)
+        return super().test(self.solver, datamodule=self.data_module, **kwargs)

     @property
     def solver(self):