Codacy correction
This commit is contained in:
committed by
Nicola Demo
parent
ea3d1924e7
commit
dd43c8304c
@@ -9,8 +9,13 @@ from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
|
||||
eval_size=.1, **kwargs):
|
||||
def __init__(self,
|
||||
solver,
|
||||
batch_size=None,
|
||||
train_size=.7,
|
||||
test_size=.2,
|
||||
eval_size=.1,
|
||||
**kwargs):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
|
||||
@@ -48,8 +53,7 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
pb.unknown_parameters[key].data.to(device)
|
||||
)
|
||||
pb.unknown_parameters[key].data.to(device))
|
||||
|
||||
def _create_loader(self):
|
||||
"""
|
||||
@@ -58,14 +62,11 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
trainer dataloader, just call the method.
|
||||
"""
|
||||
if not self.solver.problem.collector.full:
|
||||
error_message = '\n'.join(
|
||||
[
|
||||
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
|
||||
"not sampled"}"""
|
||||
for key, value in
|
||||
self._solver.problem.collector._is_conditions_ready.items()
|
||||
]
|
||||
)
|
||||
error_message = '\n'.join([
|
||||
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
|
||||
"not sampled"}""" for key, value in
|
||||
self._solver.problem.collector._is_conditions_ready.items()
|
||||
])
|
||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
@@ -76,7 +77,8 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
device = devices[0]
|
||||
|
||||
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
||||
data_module = PinaDataModule(problem=self.solver.problem,
|
||||
device=device,
|
||||
train_size=self.train_size,
|
||||
test_size=self.test_size,
|
||||
val_size=self.eval_size)
|
||||
@@ -87,9 +89,9 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
"""
|
||||
Train the solver method.
|
||||
"""
|
||||
return super().fit(
|
||||
self.solver, train_dataloaders=self._loader, **kwargs
|
||||
)
|
||||
return super().fit(self.solver,
|
||||
train_dataloaders=self._loader,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def solver(self):
|
||||
|
||||
Reference in New Issue
Block a user