Formatting
* Adding black as dev dependency
* Formatting pina code
* Formatting tests
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
pina/trainer.py (131 changed lines)
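
Most of the changes below are the mechanical rewrites black produces: single quotes become double quotes, bare float literals such as .7 gain an explicit leading zero, long signatures and calls are exploded one argument per line with a trailing comma, and backslash continuations are replaced by parentheses. A sketch of the pattern (illustrative, not taken from the commit; black is typically installed with pip install black and run as black pina tests):

# Before black (illustrative):
options = {'train_size': .7, 'test_size': .2}
# After black:
options = {"train_size": 0.7, "test_size": 0.2}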
@@ -1,4 +1,5 @@
-""" Trainer module. """
+"""Trainer module."""
+
 import sys
 import torch
 import lightning
@@ -9,18 +10,20 @@ from .solver import SolverInterface, PINNInterface


 class Trainer(lightning.pytorch.Trainer):

-    def __init__(self,
-                 solver,
-                 batch_size=None,
-                 train_size=.7,
-                 test_size=.2,
-                 val_size=.1,
-                 predict_size=0.,
-                 compile=None,
-                 automatic_batching=None,
-                 num_workers=None,
-                 pin_memory=None,
-                 **kwargs):
+    def __init__(
+        self,
+        solver,
+        batch_size=None,
+        train_size=0.7,
+        test_size=0.2,
+        val_size=0.1,
+        predict_size=0.0,
+        compile=None,
+        automatic_batching=None,
+        num_workers=None,
+        pin_memory=None,
+        **kwargs,
+    ):
         """
         PINA Trainer class for costumizing every aspect of training via flags.
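
For context, the reformatted signature is used as in the sketch below. This is hypothetical: solver stands for any SolverInterface instance built elsewhere; only the signature and the train() method come from this file.

from pina.trainer import Trainer

# solver is a placeholder for any SolverInterface instance (e.g. a PINN solver)
trainer = Trainer(
    solver,
    batch_size=32,
    train_size=0.7,
    test_size=0.2,
    val_size=0.1,
)
trainer.train()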
@@ -75,30 +78,34 @@ class Trainer(lightning.pytorch.Trainer):
         else:
             num_workers = 0
         if train_size + test_size + val_size + predict_size > 1:
-            raise ValueError('train_size, test_size, val_size and predict_size '
-                             'must sum up to 1.')
+            raise ValueError(
+                "train_size, test_size, val_size and predict_size "
+                "must sum up to 1."
+            )
         for size in [train_size, test_size, val_size, predict_size]:
             if size < 0 or size > 1:
-                raise ValueError('splitting sizes for train, validation, test '
-                                 'and prediction must be between [0, 1].')
+                raise ValueError(
+                    "splitting sizes for train, validation, test "
+                    "and prediction must be between [0, 1]."
+                )
         if batch_size is not None:
             check_consistency(batch_size, int)

         # inference mode set to false when validating/testing PINNs otherwise
         # gradient is not tracked and optimization_cycle fails
         if isinstance(solver, PINNInterface):
-            kwargs['inference_mode'] = False
+            kwargs["inference_mode"] = False

         # Logging depends on the batch size, when batch_size is None then
         # log_every_n_steps should be zero
         if batch_size is None:
-            kwargs['log_every_n_steps'] = 0
+            kwargs["log_every_n_steps"] = 0
         else:
-            kwargs.setdefault('log_every_n_steps', 50)  # default for lightning
+            kwargs.setdefault("log_every_n_steps", 50)  # default for lightning

         # Setting default kwargs, overriding lightning defaults
-        kwargs.setdefault('enable_progress_bar', True)
-        kwargs.setdefault('logger', None)
+        kwargs.setdefault("enable_progress_bar", True)
+        kwargs.setdefault("logger", None)

         super().__init__(**kwargs)

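
Note that splitting each long error message into two adjacent string literals does not change the text: Python concatenates adjacent literals at compile time. A quick check (illustrative, not part of the commit):

msg = (
    "train_size, test_size, val_size and predict_size "
    "must sum up to 1."
)
# adjacent literals are joined into a single string at compile time
assert msg == "train_size, test_size, val_size and predict_size must sum up to 1."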
@@ -106,27 +113,37 @@ class Trainer(lightning.pytorch.Trainer):
         if compile is None or sys.platform == "win32":
             compile = False

-        self.automatic_batching = automatic_batching if automatic_batching \
-            is not None else False
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else False
+        )
         # set attributes
         self.compile = compile
         self.solver = solver
         self.batch_size = batch_size
         self._move_to_device()
         self.data_module = None
-        self._create_datamodule(train_size, test_size, val_size, predict_size,
-                                batch_size, automatic_batching, pin_memory,
-                                num_workers)
+        self._create_datamodule(
+            train_size,
+            test_size,
+            val_size,
+            predict_size,
+            batch_size,
+            automatic_batching,
+            pin_memory,
+            num_workers,
+        )

         # logging
         self.logging_kwargs = {
-            'logger': bool(
-                kwargs['logger'] is None or kwargs['logger'] is True),
-            'sync_dist': bool(
-                len(self._accelerator_connector._parallel_devices) > 1),
-            'on_step': bool(kwargs['log_every_n_steps'] > 0),
-            'prog_bar': bool(kwargs['enable_progress_bar']),
-            'on_epoch': True
+            "logger": bool(
+                kwargs["logger"] is None or kwargs["logger"] is True
+            ),
+            "sync_dist": bool(
+                len(self._accelerator_connector._parallel_devices) > 1
+            ),
+            "on_step": bool(kwargs["log_every_n_steps"] > 0),
+            "prog_bar": bool(kwargs["enable_progress_bar"]),
+            "on_epoch": True,
         }

     def _move_to_device(self):
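
The keys collected in logging_kwargs match the flags of lightning's LightningModule.log (logger, sync_dist, on_step, on_epoch, prog_bar), so they can be forwarded wholesale when the solver logs a metric. How the flags resolve for a default full-batch construction, derived from the code above (the call site itself is not part of this diff):

# With batch_size=None, log_every_n_steps is forced to 0, so the flags resolve to:
logging_kwargs = {
    "logger": True,      # kwargs["logger"] is None -> True
    "sync_dist": False,  # a single parallel device
    "on_step": False,    # log_every_n_steps == 0
    "prog_bar": True,    # enable_progress_bar defaults to True
    "on_epoch": True,
}
# hypothetical call site in a solver: self.log("loss", value, **trainer.logging_kwargs)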
@@ -136,32 +153,39 @@ class Trainer(lightning.pytorch.Trainer):
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
                 pb.unknown_parameters[key] = torch.nn.Parameter(
-                    pb.unknown_parameters[key].data.to(device))
+                    pb.unknown_parameters[key].data.to(device)
+                )

-    def _create_datamodule(self,
-                           train_size,
-                           test_size,
-                           val_size,
-                           predict_size,
-                           batch_size,
-                           automatic_batching,
-                           pin_memory,
-                           num_workers):
+    def _create_datamodule(
+        self,
+        train_size,
+        test_size,
+        val_size,
+        predict_size,
+        batch_size,
+        automatic_batching,
+        pin_memory,
+        num_workers,
+    ):
         """
         This method is used here because is resampling is needed
         during training, there is no need to define to touch the
         trainer dataloader, just call the method.
         """
         if not self.solver.problem.are_all_domains_discretised:
-            error_message = '\n'.join([
-                f"""{" " * 13} ---> Domain {key} {
-                "sampled" if key in self.solver.problem.discretised_domains else
-                "not sampled"}""" for key in
-                self.solver.problem.domains.keys()
-            ])
-            raise RuntimeError('Cannot create Trainer if not all conditions '
-                               'are sampled. The Trainer got the following:\n'
-                               f'{error_message}')
+            error_message = "\n".join(
+                [
+                    f"""{" " * 13} ---> Domain {key} {
+                "sampled" if key in self.solver.problem.discretised_domains else
+                "not sampled"}"""
+                    for key in self.solver.problem.domains.keys()
+                ]
+            )
+            raise RuntimeError(
+                "Cannot create Trainer if not all conditions "
+                "are sampled. The Trainer got the following:\n"
+                f"{error_message}"
+            )
         self.data_module = PinaDataModule(
             self.solver.problem,
             train_size=train_size,
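
The guard reformatted above means every problem domain must be sampled before a Trainer is created. A minimal sketch of satisfying it, assuming PINA's discretise_domain sampling entry point (the argument values are illustrative, and problem/solver are placeholders defined elsewhere):

# problem and solver are placeholders built beforehand
problem.discretise_domain(n=100, mode="random", domains="all")
assert problem.are_all_domains_discretised
trainer = Trainer(solver)  # no RuntimeError: all domains are sampled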
@@ -171,7 +195,8 @@ class Trainer(lightning.pytorch.Trainer):
             batch_size=batch_size,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
-            pin_memory=pin_memory)
+            pin_memory=pin_memory,
+        )

     def train(self, **kwargs):
         """