Fix Codacy Warnings (#477)

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-10 15:38:45 +01:00
committed by Nicola Demo
parent e3790e049a
commit 4177bfbb50
157 changed files with 3473 additions and 3839 deletions

View File

@@ -1,16 +1,15 @@
"""Solver module."""
from abc import ABCMeta, abstractmethod
import lightning
import torch
import sys
from abc import ABCMeta, abstractmethod
from torch._dynamo.eval_frame import OptimizedModule
from ..problem import AbstractProblem
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from ..loss import WeightingInterface
from ..loss.scalar_weighting import _NoWeighting
from ..utils import check_consistency, labelize_forward
from torch._dynamo.eval_frame import OptimizedModule
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@@ -119,6 +118,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
self.store_log("test_loss", loss, self.get_batch_size(batch))
def store_log(self, name, value, batch_size):
"""
TODO
"""
self.log(
name=name,
value=value,
@@ -128,7 +131,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@abstractmethod
def forward(self, *args, **kwargs):
pass
"""
TODO
"""
@abstractmethod
def optimization_cycle(self, batch):
@@ -144,7 +149,6 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
containing the condition name and the associated scalar loss.
:rtype: dict(torch.Tensor)
"""
pass
@property
def problem(self):
@@ -169,7 +173,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def get_batch_size(batch):
# assuming batch is a custom Batch object
"""
TODO
"""
batch_size = 0
for data in batch:
batch_size += len(data[1]["input"])
@@ -177,10 +184,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def default_torch_optimizer():
"""
TODO
"""
return TorchOptimizer(torch.optim.Adam, lr=0.001)
@staticmethod
def default_torch_scheduler():
"""
TODO
"""
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
def on_train_start(self):
@@ -202,6 +217,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
self._compile_model()
def _check_already_compiled(self):
"""
TODO
"""
models = self._pina_models
if len(models) == 1 and isinstance(
self._pina_models[0], torch.nn.ModuleDict
@@ -214,6 +233,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def _perform_compilation(model):
"""
TODO
"""
model_device = next(model.parameters()).device
try:
if model_device == torch.device("mps:0"):
@@ -225,7 +248,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return model
class SingleSolverInterface(SolverInterface):
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
"""TODO"""
def __init__(
self,
problem,
@@ -322,7 +347,7 @@ class SingleSolverInterface(SolverInterface):
return self._pina_optimizers[0]
class MultiSolverInterface(SolverInterface):
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
"""
Multiple Solver base class. This class inherits is a wrapper of
SolverInterface class