Correct codacy warnings
committed by Nicola Demo
parent c9304fb9bb
commit 1bc1b3a580
@@ -1,12 +1,6 @@
 __all__ = [
-    "PINN",
-    "Trainer",
-    "LabelTensor",
-    "Plotter",
-    "Condition",
-    "SamplePointDataset",
-    "PinaDataModule",
-    "PinaDataLoader"
+    "PINN", "Trainer", "LabelTensor", "Plotter", "Condition",
+    "SamplePointDataset", "PinaDataModule", "PinaDataLoader"
 ]
 
 from .meta import *
@@ -17,4 +11,4 @@ from .plotter import Plotter
 from .condition.condition import Condition
 from .data import SamplePointDataset
 from .data import PinaDataModule
-from .data import PinaDataLoader
+from .data import PinaDataLoader
@@ -2,13 +2,8 @@
 Import data classes
 """
 __all__ = [
-    'PinaDataLoader',
-    'SupervisedDataset',
-    'SamplePointDataset',
-    'UnsupervisedDataset',
-    'Batch',
-    'PinaDataModule',
-    'BaseDataset'
+    'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
+    'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
 ]
 
 from .pina_dataloader import PinaDataLoader
@@ -22,10 +22,12 @@ class BaseDataset(Dataset):
         dataset will be loaded.
         """
         if cls is BaseDataset:
-            raise TypeError('BaseDataset cannot be instantiated directly. Use a subclass.')
+            raise TypeError(
+                'BaseDataset cannot be instantiated directly. Use a subclass.')
         if not hasattr(cls, '__slots__'):
-            raise TypeError('Something is wrong, __slots__ must be defined in subclasses.')
-        return super().__new__(cls)
+            raise TypeError(
+                'Something is wrong, __slots__ must be defined in subclasses.')
+        return super(BaseDataset, cls).__new__(cls)
 
     def __init__(self, problem, device):
         """"
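The hunk above only re-wraps two long raise statements, but the guard it touches is a pattern worth seeing whole: __new__ refuses to instantiate the base class directly and requires every subclass to declare __slots__. A minimal, self-contained sketch of that pattern, with illustrative class names rather than PINA's own:

class GuardedBase:
    """Base class usable only through subclasses that define __slots__."""

    def __new__(cls, *args, **kwargs):
        # Block direct instantiation of the base class itself.
        if cls is GuardedBase:
            raise TypeError(
                'GuardedBase cannot be instantiated directly. Use a subclass.')
        # Require every concrete subclass to declare __slots__.
        if not hasattr(cls, '__slots__'):
            raise TypeError('__slots__ must be defined in subclasses.')
        return super().__new__(cls)


class PointSet(GuardedBase):
    __slots__ = ['points']

    def __init__(self, points):
        self.points = points


PointSet([1.0, 2.0])   # works
# GuardedBase()        # would raise TypeError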
@@ -79,7 +81,8 @@ class BaseDataset(Dataset):
 
     def __getattribute__(self, item):
         attribute = super().__getattribute__(item)
-        if isinstance(attribute, LabelTensor) and attribute.dtype == torch.float32:
+        if isinstance(attribute,
+                      LabelTensor) and attribute.dtype == torch.float32:
             attribute = attribute.to(device=self.device).requires_grad_()
         return attribute
 
@@ -101,7 +104,8 @@ class BaseDataset(Dataset):
         if all(isinstance(x, int) for x in idx):
             to_return_list = []
             for i in self.__slots__:
-                to_return_list.append(getattr(self, i)[[idx]].to(self.device))
+                to_return_list.append(
+                    getattr(self, i)[[idx]].to(self.device))
             return to_return_list
 
         raise ValueError(f'Invalid index {idx}')
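Both hunks above re-wrap the dataset's attribute-access path without changing it. The behavior itself deserves a standalone illustration: __getattribute__ intercepts every lookup and hands back float32 tensors moved to the dataset's device with gradients enabled. A runnable sketch that substitutes a plain torch.Tensor check for PINA's LabelTensor:

import torch


class DeviceAwareData:
    """Sketch only: float32 tensors come back on `device`, grad-enabled."""

    def __init__(self, points, device='cpu'):
        self.device = device
        self.points = points

    def __getattribute__(self, item):
        attribute = super().__getattribute__(item)
        # No recursion risk: self.device resolves to a plain string.
        if isinstance(attribute,
                      torch.Tensor) and attribute.dtype == torch.float32:
            attribute = attribute.to(device=self.device).requires_grad_()
        return attribute


data = DeviceAwareData(torch.rand(4, 2))
print(data.points.requires_grad)  # True: set on access, not at storage time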
@@ -5,6 +5,7 @@ from .pina_subset import PinaSubset
 
+
 class Batch:
 
     def __init__(self, dataset_dict, idx_dict):
 
         for k, v in dataset_dict.items():
@@ -29,5 +30,6 @@ class Batch:
     def __getattr__(self, item):
         if not item in dir(self):
             raise AttributeError(f'Batch instance has no attribute {item}')
-        return PinaSubset(getattr(self, item).dataset,
-                          getattr(self, item).indices[self.coordinates_dict[item]])
+        return PinaSubset(
+            getattr(self, item).dataset,
+            getattr(self, item).indices[self.coordinates_dict[item]])
@@ -50,7 +50,8 @@ class PinaDataLoader:
                 temp_dict[k] = slice(i * v, (i + 1) * v)
             else:
                 temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
-            self.batches.append(Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
+            self.batches.append(
+                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
 
     def __iter__(self):
         """
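The call being re-wrapped here is the tail of a batching scheme that stores one slice per dataset in each Batch, so nothing is copied until a batch is consumed. A reduced sketch of that slicing logic; the function name and the way v is derived are assumptions, only the two slice(...) branches come from the code above:

def make_batch_slices(dataset_sizes, n_batches):
    """For each batch index, return a dict mapping dataset name -> slice."""
    batches = []
    for i in range(n_batches):
        temp_dict = {}
        for k, size in dataset_sizes.items():
            v = size // n_batches  # items of dataset k per batch (assumed)
            if i < n_batches - 1:
                temp_dict[k] = slice(i * v, (i + 1) * v)
            else:
                temp_dict[k] = slice(i * v, size)  # last batch takes the rest
        batches.append(temp_dict)
    return batches


for batch in make_batch_slices({'supervised': 10, 'physics': 7}, n_batches=3):
    print(batch)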
@@ -5,6 +5,7 @@ import torch
 from ..utils import check_consistency
 from .optimizer_interface import Optimizer
 
+
 class TorchOptimizer(Optimizer):
 
     def __init__(self, optimizer_class, **kwargs):
@@ -14,6 +15,5 @@ class TorchOptimizer(Optimizer):
         self.kwargs = kwargs
 
     def hook(self, parameters):
-        self.optimizer_instance = self.optimizer_class(
-            parameters, **self.kwargs
-        )
+        self.optimizer_instance = self.optimizer_class(parameters,
+                                                       **self.kwargs)
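hook, reformatted above, is half of a two-phase pattern: the wrapper records the optimizer class and its keyword arguments at configuration time, and builds the actual torch.optim object only once model parameters exist. A self-contained sketch with an illustrative wrapper name:

import torch


class DeferredOptimizer:
    """Store the class and kwargs now, build the optimizer later."""

    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs

    def hook(self, parameters):
        # Instantiation is deferred until parameters are available.
        self.optimizer_instance = self.optimizer_class(parameters,
                                                       **self.kwargs)


model = torch.nn.Linear(2, 1)
opt = DeferredOptimizer(torch.optim.Adam, lr=0.001)
opt.hook(model.parameters())  # the Adam instance is created only here
print(type(opt.optimizer_instance).__name__)  # Adam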
@@ -5,13 +5,13 @@ try:
     from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
 except ImportError:
     from torch.optim.lr_scheduler import (
-        _LRScheduler as LRScheduler,
-    )  # torch < 2.0
+        _LRScheduler as LRScheduler, )  # torch < 2.0
 
 from ..utils import check_consistency
 from .optimizer_interface import Optimizer
 from .scheduler_interface import Scheduler
 
+
 class TorchScheduler(Scheduler):
 
     def __init__(self, scheduler_class, **kwargs):
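The try/except kept intact by this hunk is the usual compatibility shim for the scheduler base class, which torch exposed privately as _LRScheduler before 2.0 and publicly as LRScheduler from 2.0 on; only the closing parenthesis moved. For reference, the same shim written on single lines:

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0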
@@ -23,5 +23,4 @@ class TorchScheduler(Scheduler):
     def hook(self, optimizer):
         check_consistency(optimizer, Optimizer)
         self.scheduler_instance = self.scheduler_class(
-            optimizer.optimizer_instance, **self.kwargs
-        )
+            optimizer.optimizer_instance, **self.kwargs)
@@ -17,15 +17,13 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     LightningModule methods.
     """
 
-    def __init__(
-        self,
-        models,
-        problem,
-        optimizers,
-        schedulers,
-        extra_features,
-        use_lt=True
-    ):
+    def __init__(self,
+                 models,
+                 problem,
+                 optimizers,
+                 schedulers,
+                 extra_features,
+                 use_lt=True):
         """
         :param model: A torch neural network model instance.
         :type model: torch.nn.Module
@@ -55,10 +53,11 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         if use_lt is True:
             for idx in range(len(models)):
                 models[idx] = Network(
-                    model = models[idx],
+                    model=models[idx],
                     input_variables=problem.input_variables,
                     output_variables=problem.output_variables,
-                    extra_features=extra_features, )
+                    extra_features=extra_features,
+                )
 
         #Check scheduler consistency + encapsulation
         if not isinstance(schedulers, list):
@@ -79,11 +78,9 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
 
         # check length consistency optimizers
         if len_model != len_optimizer:
-            raise ValueError(
-                "You must define one optimizer for each model."
-                f"Got {len_model} models, and {len_optimizer}"
-                " optimizers."
-            )
+            raise ValueError("You must define one optimizer for each model."
+                             f"Got {len_model} models, and {len_optimizer}"
+                             " optimizers.")
 
         # extra features handling
 
@@ -92,7 +89,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         self._pina_schedulers = schedulers
         self._pina_problem = problem
 
-
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
@@ -142,5 +138,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         TODO
         """
         for _, condition in problem.conditions.items():
-            if not set(self.accepted_condition_types).issubset(condition.condition_type):
-                raise ValueError(f'{self.__name__} support only dose not support condition {condition.condition_type}')
+            if not set(self.accepted_condition_types).issubset(
+                    condition.condition_type):
+                raise ValueError(
+                    f'{self.__name__} support only dose not support condition {condition.condition_type}'
+                )
@@ -40,15 +40,13 @@ class SupervisedSolver(SolverInterface):
     accepted_condition_types = ['supervised']
     __name__ = 'SupervisedSolver'
 
-    def __init__(
-        self,
-        problem,
-        model,
-        loss=None,
-        optimizer=None,
-        scheduler=None,
-        extra_features=None
-    ):
+    def __init__(self,
+                 problem,
+                 model,
+                 loss=None,
+                 optimizer=None,
+                 scheduler=None,
+                 extra_features=None):
         """
         :param AbstractProblem problem: The formualation of the problem.
         :param torch.nn.Module model: The neural network model to use.
@@ -68,16 +66,13 @@ class SupervisedSolver(SolverInterface):
             optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
 
         if scheduler is None:
-            scheduler = TorchScheduler(
-                torch.optim.lr_scheduler.ConstantLR)
+            scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
 
-        super().__init__(
-            models=model,
-            problem=problem,
-            optimizers=optimizer,
-            schedulers=scheduler,
-            extra_features=extra_features
-        )
+        super().__init__(models=model,
+                         problem=problem,
+                         optimizers=optimizer,
+                         schedulers=scheduler,
+                         extra_features=extra_features)
 
         # check consistency
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
@@ -107,10 +102,8 @@ class SupervisedSolver(SolverInterface):
         """
         self._optimizer.hook(self._model.parameters())
         self._scheduler.hook(self._optimizer)
-        return (
-            [self._optimizer.optimizer_instance],
-            [self._scheduler.scheduler_instance]
-        )
+        return ([self._optimizer.optimizer_instance],
+                [self._scheduler.scheduler_instance])
 
     def training_step(self, batch, batch_idx):
         """Solver training step.
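The reformatted return follows PyTorch Lightning's configure_optimizers contract, which accepts a pair of lists: optimizers first, schedulers second. A minimal sketch of the same contract outside PINA, assuming a single model and the constant learning-rate default used above:

import torch
import pytorch_lightning


class TinySolver(pytorch_lightning.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(2, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
        # Lightning accepts ([optimizers], [schedulers]).
        return [optimizer], [scheduler]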
@@ -136,8 +129,7 @@ class SupervisedSolver(SolverInterface):
             # for data driven mode
             if not hasattr(condition, "output_points"):
                 raise NotImplementedError(
-                    f"{type(self).__name__} works only in data-driven mode."
-                )
+                    f"{type(self).__name__} works only in data-driven mode.")
 
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
@@ -145,9 +137,7 @@ class SupervisedSolver(SolverInterface):
             input_pts.labels = pts.labels
             output_pts.labels = out.labels
 
-            loss = (
-                self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            )
+            loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
             loss = loss.as_subclass(torch.Tensor)
 
         self.log("mean_loss", float(loss), prog_bar=True, logger=True)
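The loss.as_subclass(torch.Tensor) line preserved by this hunk downcasts the loss from PINA's LabelTensor to a plain tensor before it reaches Lightning's logging. A standalone illustration, with TaggedTensor as a hypothetical stand-in for such a subclass:

import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for a labeled tensor subclass."""


loss = torch.ones(3).as_subclass(TaggedTensor).mean()
plain = loss.as_subclass(torch.Tensor)  # drop the subclass, keep the data
print(type(loss).__name__, type(plain).__name__)  # TaggedTensor Tensor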