Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
c9304fb9bb
commit
1bc1b3a580
@@ -40,15 +40,13 @@ class SupervisedSolver(SolverInterface):
|
||||
accepted_condition_types = ['supervised']
|
||||
__name__ = 'SupervisedSolver'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
extra_features=None
|
||||
):
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
extra_features=None):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
@@ -68,16 +66,13 @@ class SupervisedSolver(SolverInterface):
|
||||
optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = TorchScheduler(
|
||||
torch.optim.lr_scheduler.ConstantLR)
|
||||
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
|
||||
|
||||
super().__init__(
|
||||
models=model,
|
||||
problem=problem,
|
||||
optimizers=optimizer,
|
||||
schedulers=scheduler,
|
||||
extra_features=extra_features
|
||||
)
|
||||
super().__init__(models=model,
|
||||
problem=problem,
|
||||
optimizers=optimizer,
|
||||
schedulers=scheduler,
|
||||
extra_features=extra_features)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
@@ -107,10 +102,8 @@ class SupervisedSolver(SolverInterface):
|
||||
"""
|
||||
self._optimizer.hook(self._model.parameters())
|
||||
self._scheduler.hook(self._optimizer)
|
||||
return (
|
||||
[self._optimizer.optimizer_instance],
|
||||
[self._scheduler.scheduler_instance]
|
||||
)
|
||||
return ([self._optimizer.optimizer_instance],
|
||||
[self._scheduler.scheduler_instance])
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Solver training step.
|
||||
@@ -136,8 +129,7 @@ class SupervisedSolver(SolverInterface):
|
||||
# for data driven mode
|
||||
if not hasattr(condition, "output_points"):
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} works only in data-driven mode."
|
||||
)
|
||||
f"{type(self).__name__} works only in data-driven mode.")
|
||||
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
@@ -145,9 +137,7 @@ class SupervisedSolver(SolverInterface):
|
||||
input_pts.labels = pts.labels
|
||||
output_pts.labels = out.labels
|
||||
|
||||
loss = (
|
||||
self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
)
|
||||
loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
||||
|
||||
Reference in New Issue
Block a user