Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -1,16 +1,15 @@
|
||||
"""Solver module."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import lightning
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
from ..problem import AbstractProblem
|
||||
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
||||
from ..loss import WeightingInterface
|
||||
from ..loss.scalar_weighting import _NoWeighting
|
||||
from ..utils import check_consistency, labelize_forward
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
@@ -119,6 +118,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
|
||||
def store_log(self, name, value, batch_size):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
self.log(
|
||||
name=name,
|
||||
value=value,
|
||||
@@ -128,7 +131,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
pass
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def optimization_cycle(self, batch):
|
||||
@@ -144,7 +149,6 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict(torch.Tensor)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
@@ -169,7 +173,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
# assuming batch is a custom Batch object
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
batch_size = 0
|
||||
for data in batch:
|
||||
batch_size += len(data[1]["input"])
|
||||
@@ -177,10 +184,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
@staticmethod
|
||||
def default_torch_optimizer():
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
return TorchOptimizer(torch.optim.Adam, lr=0.001)
|
||||
|
||||
@staticmethod
|
||||
def default_torch_scheduler():
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
|
||||
|
||||
def on_train_start(self):
|
||||
@@ -202,6 +217,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
self._compile_model()
|
||||
|
||||
def _check_already_compiled(self):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
models = self._pina_models
|
||||
if len(models) == 1 and isinstance(
|
||||
self._pina_models[0], torch.nn.ModuleDict
|
||||
@@ -214,6 +233,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
@staticmethod
|
||||
def _perform_compilation(model):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
model_device = next(model.parameters()).device
|
||||
try:
|
||||
if model_device == torch.device("mps:0"):
|
||||
@@ -225,7 +248,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
return model
|
||||
|
||||
|
||||
class SingleSolverInterface(SolverInterface):
|
||||
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""TODO"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
@@ -322,7 +347,7 @@ class SingleSolverInterface(SolverInterface):
|
||||
return self._pina_optimizers[0]
|
||||
|
||||
|
||||
class MultiSolverInterface(SolverInterface):
|
||||
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Multiple Solver base class. This class inherits is a wrapper of
|
||||
SolverInterface class
|
||||
|
||||
Reference in New Issue
Block a user