Correct codacy warnings

This commit is contained in:
FilippoOlivo
2024-10-22 14:54:22 +02:00
committed by Nicola Demo
parent 1bc1b3a580
commit 3e30450e9a
10 changed files with 60 additions and 37 deletions

View File

@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
eval_size=.1, **kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -39,10 +40,9 @@ class Trainer(pytorch_lightning.Trainer):
self._create_loader()
self._move_to_device()
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self.solver.problem
if hasattr(pb, "unknown_parameters"):
@@ -59,11 +59,13 @@ class Trainer(pytorch_lightning.Trainer):
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
[
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in
self.solver.problem.collector._is_conditions_ready.items()])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
'are sampled. The Trainer got the following:\n'
f'{error_message}')
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem, device=device,
train_size=self.train_size, test_size=self.test_size,
train_size=self.train_size,
test_size=self.test_size,
eval_size=self.eval_size)
data_module.setup()
self._loader = data_module.train_dataloader()