Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
112
pina/trainer.py
112
pina/trainer.py
@@ -9,15 +9,18 @@ from .solver import SolverInterface, PINNInterface
|
||||
|
||||
|
||||
class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
PINA custom Trainer class which allows to customize standard Lightning
|
||||
Trainer class for PINNs training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
solver,
|
||||
batch_size=None,
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
predict_size=0.0,
|
||||
train_size=1.0,
|
||||
test_size=0.0,
|
||||
val_size=0.0,
|
||||
compile=None,
|
||||
automatic_batching=None,
|
||||
num_workers=None,
|
||||
@@ -26,7 +29,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
Initialize the Trainer class for by calling Lightning costructor and
|
||||
adding many other functionalities.
|
||||
|
||||
:param solver: A pina:class:`SolverInterface` solver for the
|
||||
differential problem.
|
||||
@@ -41,8 +45,6 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:type test_size: float
|
||||
:param val_size: Percentage of elements in the val dataset.
|
||||
:type val_size: float
|
||||
:param predict_size: Percentage of elements in the predict dataset.
|
||||
:type predict_size: float
|
||||
:param compile: if True model is compiled before training,
|
||||
default False. For Windows users compilation is always disabled.
|
||||
:type compile: bool
|
||||
@@ -62,43 +64,23 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments specify the training setup
|
||||
and can be choosen from the `pytorch-lightning
|
||||
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
|
||||
Trainer API <https://lightning.ai/docs/pytorch/stable/common/
|
||||
trainer.html#trainer-class-api>`_
|
||||
"""
|
||||
# check consistency for init types
|
||||
check_consistency(solver, SolverInterface)
|
||||
check_consistency(train_size, float)
|
||||
check_consistency(test_size, float)
|
||||
check_consistency(val_size, float)
|
||||
check_consistency(predict_size, float)
|
||||
if automatic_batching is not None:
|
||||
check_consistency(automatic_batching, bool)
|
||||
if compile is not None:
|
||||
check_consistency(compile, bool)
|
||||
if pin_memory is not None:
|
||||
check_consistency(pin_memory, bool)
|
||||
else:
|
||||
pin_memory = False
|
||||
if num_workers is not None:
|
||||
check_consistency(pin_memory, int)
|
||||
else:
|
||||
num_workers = 0
|
||||
if shuffle is not None:
|
||||
check_consistency(shuffle, bool)
|
||||
else:
|
||||
shuffle = True
|
||||
if train_size + test_size + val_size + predict_size > 1:
|
||||
raise ValueError(
|
||||
"train_size, test_size, val_size and predict_size "
|
||||
"must sum up to 1."
|
||||
self._check_input_consistency(
|
||||
solver,
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
automatic_batching,
|
||||
compile,
|
||||
)
|
||||
pin_memory, num_workers, shuffle, batch_size = (
|
||||
self._check_consistency_and_set_defaults(
|
||||
pin_memory, num_workers, shuffle, batch_size
|
||||
)
|
||||
for size in [train_size, test_size, val_size, predict_size]:
|
||||
if size < 0 or size > 1:
|
||||
raise ValueError(
|
||||
"splitting sizes for train, validation, test "
|
||||
"and prediction must be between [0, 1]."
|
||||
)
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
)
|
||||
|
||||
# inference mode set to false when validating/testing PINNs otherwise
|
||||
# gradient is not tracked and optimization_cycle fails
|
||||
@@ -125,6 +107,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
self.automatic_batching = (
|
||||
automatic_batching if automatic_batching is not None else False
|
||||
)
|
||||
|
||||
# set attributes
|
||||
self.compile = compile
|
||||
self.solver = solver
|
||||
@@ -135,7 +118,6 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
predict_size,
|
||||
batch_size,
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
@@ -171,7 +153,6 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
predict_size,
|
||||
batch_size,
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
@@ -187,7 +168,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
error_message = "\n".join(
|
||||
[
|
||||
f"""{" " * 13} ---> Domain {key} {
|
||||
"sampled" if key in self.solver.problem.discretised_domains else
|
||||
"sampled" if key in self.solver.problem.discretised_domains
|
||||
else
|
||||
"not sampled"}"""
|
||||
for key in self.solver.problem.domains.keys()
|
||||
]
|
||||
@@ -202,7 +184,6 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
train_size=train_size,
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
predict_size=predict_size,
|
||||
batch_size=batch_size,
|
||||
automatic_batching=automatic_batching,
|
||||
num_workers=num_workers,
|
||||
@@ -232,3 +213,44 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
@solver.setter
|
||||
def solver(self, solver):
|
||||
self._solver = solver
|
||||
|
||||
@staticmethod
|
||||
def _check_input_consistency(
|
||||
solver, train_size, test_size, val_size, automatic_batching, compile
|
||||
):
|
||||
"""
|
||||
Check the consistency of the input parameters."
|
||||
"""
|
||||
|
||||
check_consistency(solver, SolverInterface)
|
||||
check_consistency(train_size, float)
|
||||
check_consistency(test_size, float)
|
||||
check_consistency(val_size, float)
|
||||
if automatic_batching is not None:
|
||||
check_consistency(automatic_batching, bool)
|
||||
if compile is not None:
|
||||
check_consistency(compile, bool)
|
||||
|
||||
@staticmethod
|
||||
def _check_consistency_and_set_defaults(
|
||||
pin_memory, num_workers, shuffle, batch_size
|
||||
):
|
||||
"""
|
||||
Check the consistency of the input parameters and set the default
|
||||
values.
|
||||
"""
|
||||
if pin_memory is not None:
|
||||
check_consistency(pin_memory, bool)
|
||||
else:
|
||||
pin_memory = False
|
||||
if num_workers is not None:
|
||||
check_consistency(pin_memory, int)
|
||||
else:
|
||||
num_workers = 0
|
||||
if shuffle is not None:
|
||||
check_consistency(shuffle, bool)
|
||||
else:
|
||||
shuffle = True
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
return pin_memory, num_workers, shuffle, batch_size
|
||||
|
||||
Reference in New Issue
Block a user