Fix Codacy Warnings (#477)

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
Filippo Olivo
2025-03-10 15:38:45 +01:00
committed by Nicola Demo
parent e3790e049a
commit 4177bfbb50
157 changed files with 3473 additions and 3839 deletions


@@ -9,15 +9,18 @@ from .solver import SolverInterface, PINNInterface
class Trainer(lightning.pytorch.Trainer):
    """
    PINA custom Trainer class, which allows customizing the standard
    Lightning Trainer class for PINN training.
    """

    def __init__(
        self,
        solver,
        batch_size=None,
        train_size=0.7,
        test_size=0.2,
        val_size=0.1,
        predict_size=0.0,
        train_size=1.0,
        test_size=0.0,
        val_size=0.0,
        compile=None,
        automatic_batching=None,
        num_workers=None,
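
The hunk above moves the default data split from 70/20/10 (plus an empty
predict split) to training on the full dataset. A minimal usage sketch of
the new signature, assuming `solver` is an already-built SolverInterface
instance; `max_epochs` and `accelerator` are standard Lightning keyword
arguments forwarded through **kwargs:

    trainer = Trainer(
        solver=solver,      # any SolverInterface instance
        train_size=0.7,     # the old defaults must now be passed explicitly
        test_size=0.2,
        val_size=0.1,
        max_epochs=100,     # forwarded to lightning.pytorch.Trainer
        accelerator="cpu",  # forwarded to lightning.pytorch.Trainer
    )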
@@ -26,7 +29,8 @@ class Trainer(lightning.pytorch.Trainer):
        **kwargs,
    ):
        """
        PINA Trainer class for customizing every aspect of training via flags.
        Initialize the Trainer class by calling the Lightning constructor and
        adding several further functionalities.

        :param solver: A PINA :class:`SolverInterface` solver for the
            differential problem.
@@ -41,8 +45,6 @@ class Trainer(lightning.pytorch.Trainer):
        :type test_size: float
        :param val_size: Percentage of elements in the validation dataset.
        :type val_size: float
        :param predict_size: Percentage of elements in the predict dataset.
        :type predict_size: float
        :param compile: If True, the model is compiled before training;
            default False. For Windows users compilation is always disabled.
        :type compile: bool
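
The `compile` flag documented above corresponds to PyTorch 2.x graph
compilation, which is unsupported on Windows (hence the forced disable). A
standalone sketch of roughly what enabling it amounts to, not PINA's exact
call site:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    compiled = torch.compile(model)  # roughly what compile=True enables
    y = compiled(torch.randn(8, 4))  # first call triggers compilation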
@@ -62,43 +64,23 @@ class Trainer(lightning.pytorch.Trainer):
        :Keyword Arguments:
            The additional keyword arguments specify the training setup
            and can be chosen from the `pytorch-lightning
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/
            trainer.html#trainer-class-api>`_
        """
        # check consistency for init types
        check_consistency(solver, SolverInterface)
        check_consistency(train_size, float)
        check_consistency(test_size, float)
        check_consistency(val_size, float)
        check_consistency(predict_size, float)
        if automatic_batching is not None:
            check_consistency(automatic_batching, bool)
        if compile is not None:
            check_consistency(compile, bool)
        if pin_memory is not None:
            check_consistency(pin_memory, bool)
        else:
            pin_memory = False
        if num_workers is not None:
            check_consistency(num_workers, int)
        else:
            num_workers = 0
        if shuffle is not None:
            check_consistency(shuffle, bool)
        else:
            shuffle = True
        if train_size + test_size + val_size + predict_size > 1:
            raise ValueError(
                "train_size, test_size, val_size and predict_size "
                "must sum up to 1."
            )
        for size in [train_size, test_size, val_size, predict_size]:
            if size < 0 or size > 1:
                raise ValueError(
                    "splitting sizes for train, validation, test "
                    "and prediction must be between [0, 1]."
                )
        if batch_size is not None:
            check_consistency(batch_size, int)
        self._check_input_consistency(
            solver,
            train_size,
            test_size,
            val_size,
            automatic_batching,
            compile,
        )
        pin_memory, num_workers, shuffle, batch_size = (
            self._check_consistency_and_set_defaults(
                pin_memory, num_workers, shuffle, batch_size
            )
        )
        # inference mode is set to False when validating/testing PINNs,
        # otherwise the gradient is not tracked and optimization_cycle fails
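
To see why the comment above matters: under inference mode no autograd graph
is recorded, while PINN losses need input gradients to evaluate the physics
residuals. A self-contained demonstration:

    import torch

    x = torch.ones(2, requires_grad=True)
    with torch.inference_mode():
        y = (x ** 2).sum()
    print(y.requires_grad)  # False: no graph recorded, so
                            # torch.autograd.grad(y, x) would raise here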
@@ -125,6 +107,7 @@ class Trainer(lightning.pytorch.Trainer):
        self.automatic_batching = (
            automatic_batching if automatic_batching is not None else False
        )

        # set attributes
        self.compile = compile
        self.solver = solver
@@ -135,7 +118,6 @@ class Trainer(lightning.pytorch.Trainer):
            train_size,
            test_size,
            val_size,
            predict_size,
            batch_size,
            automatic_batching,
            pin_memory,
@@ -171,7 +153,6 @@ class Trainer(lightning.pytorch.Trainer):
            train_size,
            test_size,
            val_size,
            predict_size,
            batch_size,
            automatic_batching,
            pin_memory,
@@ -187,7 +168,8 @@ class Trainer(lightning.pytorch.Trainer):
        error_message = "\n".join(
            [
                f"""{" " * 13} ---> Domain {key} {
                    "sampled" if key in self.solver.problem.discretised_domains else
                    "sampled" if key in self.solver.problem.discretised_domains
                    else
                    "not sampled"}"""
                for key in self.solver.problem.domains.keys()
            ]
        )
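
A standalone sketch of the message this comprehension builds, with
hypothetical domain names standing in for a real problem's domains:

    discretised_domains = {"interior"}
    domains = ["interior", "boundary"]
    error_message = "\n".join(
        f"{' ' * 13} ---> Domain {key} "
        f"{'sampled' if key in discretised_domains else 'not sampled'}"
        for key in domains
    )
    print(error_message)
    #               ---> Domain interior sampled
    #               ---> Domain boundary not sampled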
@@ -202,7 +184,6 @@ class Trainer(lightning.pytorch.Trainer):
            train_size=train_size,
            test_size=test_size,
            val_size=val_size,
            predict_size=predict_size,
            batch_size=batch_size,
            automatic_batching=automatic_batching,
            num_workers=num_workers,
@@ -232,3 +213,44 @@ class Trainer(lightning.pytorch.Trainer):
    @solver.setter
    def solver(self, solver):
        self._solver = solver

    @staticmethod
    def _check_input_consistency(
        solver, train_size, test_size, val_size, automatic_batching, compile
    ):
        """
        Check the consistency of the input parameters.
        """
        check_consistency(solver, SolverInterface)
        check_consistency(train_size, float)
        check_consistency(test_size, float)
        check_consistency(val_size, float)
        if automatic_batching is not None:
            check_consistency(automatic_batching, bool)
        if compile is not None:
            check_consistency(compile, bool)

    @staticmethod
    def _check_consistency_and_set_defaults(
        pin_memory, num_workers, shuffle, batch_size
    ):
        """
        Check the consistency of the input parameters and set the default
        values.
        """
        if pin_memory is not None:
            check_consistency(pin_memory, bool)
        else:
            pin_memory = False
        if num_workers is not None:
            check_consistency(num_workers, int)
        else:
            num_workers = 0
        if shuffle is not None:
            check_consistency(shuffle, bool)
        else:
            shuffle = True
        if batch_size is not None:
            check_consistency(batch_size, int)
        return pin_memory, num_workers, shuffle, batch_size
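
A quick check of the defaulting behaviour the second helper implements,
assuming the Trainer class from this diff is importable: passing None for
every option skips all check_consistency calls and returns the defaults.

    pin_memory, num_workers, shuffle, batch_size = (
        Trainer._check_consistency_and_set_defaults(None, None, None, None)
    )
    print(pin_memory, num_workers, shuffle, batch_size)  # False 0 True None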