Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,5 @@
""" Trainer module. """
"""Trainer module."""
import sys
import torch
import lightning
@@ -9,18 +10,20 @@ from .solver import SolverInterface, PINNInterface
class Trainer(lightning.pytorch.Trainer):
def __init__(self,
solver,
batch_size=None,
train_size=.7,
test_size=.2,
val_size=.1,
predict_size=0.,
compile=None,
automatic_batching=None,
num_workers=None,
pin_memory=None,
**kwargs):
def __init__(
self,
solver,
batch_size=None,
train_size=0.7,
test_size=0.2,
val_size=0.1,
predict_size=0.0,
compile=None,
automatic_batching=None,
num_workers=None,
pin_memory=None,
**kwargs,
):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -75,30 +78,34 @@ class Trainer(lightning.pytorch.Trainer):
else:
num_workers = 0
if train_size + test_size + val_size + predict_size > 1:
raise ValueError('train_size, test_size, val_size and predict_size '
'must sum up to 1.')
raise ValueError(
"train_size, test_size, val_size and predict_size "
"must sum up to 1."
)
for size in [train_size, test_size, val_size, predict_size]:
if size < 0 or size > 1:
raise ValueError('splitting sizes for train, validation, test '
'and prediction must be between [0, 1].')
raise ValueError(
"splitting sizes for train, validation, test "
"and prediction must be between [0, 1]."
)
if batch_size is not None:
check_consistency(batch_size, int)
# inference mode set to false when validating/testing PINNs otherwise
# gradient is not tracked and optimization_cycle fails
if isinstance(solver, PINNInterface):
kwargs['inference_mode'] = False
kwargs["inference_mode"] = False
# Logging depends on the batch size, when batch_size is None then
# log_every_n_steps should be zero
if batch_size is None:
kwargs['log_every_n_steps'] = 0
kwargs["log_every_n_steps"] = 0
else:
kwargs.setdefault('log_every_n_steps', 50) # default for lightning
kwargs.setdefault("log_every_n_steps", 50) # default for lightning
# Setting default kwargs, overriding lightning defaults
kwargs.setdefault('enable_progress_bar', True)
kwargs.setdefault('logger', None)
kwargs.setdefault("enable_progress_bar", True)
kwargs.setdefault("logger", None)
super().__init__(**kwargs)
@@ -106,27 +113,37 @@ class Trainer(lightning.pytorch.Trainer):
if compile is None or sys.platform == "win32":
compile = False
self.automatic_batching = automatic_batching if automatic_batching \
is not None else False
self.automatic_batching = (
automatic_batching if automatic_batching is not None else False
)
# set attributes
self.compile = compile
self.solver = solver
self.batch_size = batch_size
self._move_to_device()
self.data_module = None
self._create_datamodule(train_size, test_size, val_size, predict_size,
batch_size, automatic_batching, pin_memory,
num_workers)
self._create_datamodule(
train_size,
test_size,
val_size,
predict_size,
batch_size,
automatic_batching,
pin_memory,
num_workers,
)
# logging
self.logging_kwargs = {
'logger': bool(
kwargs['logger'] is None or kwargs['logger'] is True),
'sync_dist': bool(
len(self._accelerator_connector._parallel_devices) > 1),
'on_step': bool(kwargs['log_every_n_steps'] > 0),
'prog_bar': bool(kwargs['enable_progress_bar']),
'on_epoch': True
"logger": bool(
kwargs["logger"] is None or kwargs["logger"] is True
),
"sync_dist": bool(
len(self._accelerator_connector._parallel_devices) > 1
),
"on_step": bool(kwargs["log_every_n_steps"] > 0),
"prog_bar": bool(kwargs["enable_progress_bar"]),
"on_epoch": True,
}
def _move_to_device(self):
@@ -136,32 +153,39 @@ class Trainer(lightning.pytorch.Trainer):
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device))
pb.unknown_parameters[key].data.to(device)
)
def _create_datamodule(self,
train_size,
test_size,
val_size,
predict_size,
batch_size,
automatic_batching,
pin_memory,
num_workers):
def _create_datamodule(
self,
train_size,
test_size,
val_size,
predict_size,
batch_size,
automatic_batching,
pin_memory,
num_workers,
):
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
if not self.solver.problem.are_all_domains_discretised:
error_message = '\n'.join([
f"""{" " * 13} ---> Domain {key} {
error_message = "\n".join(
[
f"""{" " * 13} ---> Domain {key} {
"sampled" if key in self.solver.problem.discretised_domains else
"not sampled"}""" for key in
self.solver.problem.domains.keys()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
"not sampled"}"""
for key in self.solver.problem.domains.keys()
]
)
raise RuntimeError(
"Cannot create Trainer if not all conditions "
"are sampled. The Trainer got the following:\n"
f"{error_message}"
)
self.data_module = PinaDataModule(
self.solver.problem,
train_size=train_size,
@@ -171,7 +195,8 @@ class Trainer(lightning.pytorch.Trainer):
batch_size=batch_size,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory)
pin_memory=pin_memory,
)
def train(self, **kwargs):
"""