Codacy correction

This commit is contained in:
FilippoOlivo
2024-10-31 09:50:19 +01:00
committed by Nicola Demo
parent ea3d1924e7
commit dd43c8304c
23 changed files with 246 additions and 214 deletions

View File

@@ -9,8 +9,13 @@ from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
eval_size=.1, **kwargs):
def __init__(self,
solver,
batch_size=None,
train_size=.7,
test_size=.2,
eval_size=.1,
**kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -48,8 +53,7 @@ class Trainer(pytorch_lightning.Trainer):
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
pb.unknown_parameters[key].data.to(device))
def _create_loader(self):
"""
@@ -58,14 +62,11 @@ class Trainer(pytorch_lightning.Trainer):
trainer dataloader, just call the method.
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}"""
for key, value in
self._solver.problem.collector._is_conditions_ready.items()
]
)
error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}""" for key, value in
self._solver.problem.collector._is_conditions_ready.items()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
@@ -76,7 +77,8 @@ class Trainer(pytorch_lightning.Trainer):
device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem, device=device,
data_module = PinaDataModule(problem=self.solver.problem,
device=device,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.eval_size)
@@ -87,9 +89,9 @@ class Trainer(pytorch_lightning.Trainer):
"""
Train the solver method.
"""
return super().fit(
self.solver, train_dataloaders=self._loader, **kwargs
)
return super().fit(self.solver,
train_dataloaders=self._loader,
**kwargs)
@property
def solver(self):