Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
1bc1b3a580
commit
3e30450e9a
@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
|
||||
eval_size=.1, **kwargs):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
|
||||
@@ -39,10 +40,9 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
|
||||
def _move_to_device(self):
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
|
||||
|
||||
# move parameters to device
|
||||
pb = self.solver.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
@@ -59,11 +59,13 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
"""
|
||||
if not self.solver.problem.collector.full:
|
||||
error_message = '\n'.join(
|
||||
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
|
||||
[
|
||||
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||
for key, value in
|
||||
self.solver.problem.collector._is_conditions_ready.items()])
|
||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
devices = self._accelerator_connector._parallel_devices
|
||||
|
||||
if len(devices) > 1:
|
||||
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
device = devices[0]
|
||||
|
||||
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
||||
train_size=self.train_size, test_size=self.test_size,
|
||||
train_size=self.train_size,
|
||||
test_size=self.test_size,
|
||||
eval_size=self.eval_size)
|
||||
data_module.setup()
|
||||
self._loader = data_module.train_dataloader()
|
||||
|
||||
Reference in New Issue
Block a user