Correct codacy warnings

This commit is contained in:
FilippoOlivo
2024-10-22 14:54:22 +02:00
committed by Nicola Demo
parent 1bc1b3a580
commit 3e30450e9a
10 changed files with 60 additions and 37 deletions

View File

@@ -48,7 +48,8 @@ class Collector:
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not attribute
# of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
@@ -69,7 +70,8 @@ class Collector:
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [self.data_collections[loc]['input_points']]
already_sampled = [
self.data_collections[loc]['input_points']]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
@@ -77,11 +79,13 @@ class Collector:
# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
condition.domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (
set(pts.labels).issubset(sorted(self.problem.input_variables))
set(pts.labels).issubset(
sorted(self.problem.input_variables))
):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
@@ -89,7 +93,8 @@ class Collector:
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
raise RuntimeError(
'Try to sample variables which are not in problem defined in the problem')
def add_points(self, new_points_dict):
"""
@@ -100,5 +105,7 @@ class Collector:
"""
for k, v in new_points_dict.items():
if not self._is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)

View File

@@ -5,6 +5,7 @@ from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
@@ -38,23 +39,23 @@ class Condition:
"""
__slots__ = list(
set(
InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__
)
)
set(
InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__
)
)
def __new__(cls, *args, **kwargs):
if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)
sorted_keys = sorted(kwargs.keys())
sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
@@ -66,4 +67,4 @@ class Condition:
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

View File

@@ -27,7 +27,7 @@ class BaseDataset(Dataset):
if not hasattr(cls, '__slots__'):
raise TypeError(
'Something is wrong, __slots__ must be defined in subclasses.')
return super(BaseDataset, cls).__new__(cls)
return object.__new__(cls)
def __init__(self, problem, device):
""""

View File

@@ -26,7 +26,7 @@ class PinaDataModule(LightningDataModule):
eval_size=.1,
batch_size=None,
shuffle=True,
datasets = None):
datasets=None):
"""
Initialize the object, creating dataset based on input problem
:param AbstractProblem problem: PINA problem
@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
:param datasets: list of datasets objects
"""
super().__init__()
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset]
dataset_classes = [SupervisedDataset, UnsupervisedDataset,
SamplePointDataset]
if datasets is None:
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes]
self.datasets = [DatasetClass(problem, device) for DatasetClass in
dataset_classes]
else:
self.datasets = datasets
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
for key, value in dataset.condition_names.items()
}
def train_dataloader(self):
"""
Return the training dataloader for the dataset
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
indices = torch.randperm(sum(lengths), generator=generator).tolist()
indices = torch.randperm(sum(lengths),
generator=generator).tolist()
else:
indices = torch.arange(sum(lengths)).tolist()
else:
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
indices = torch.arange(0, sum(lengths), 1,
dtype=torch.uint8).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]

View File

@@ -5,6 +5,9 @@ from .pina_subset import PinaSubset
class Batch:
"""
Implementation of the Batch class used during training to perform SGD optimization.
"""
def __init__(self, dataset_dict, idx_dict):

View File

@@ -33,7 +33,7 @@ class PinaDataLoader:
Create batches according to the batch_size provided in input.
"""
self.batches = []
n_elements = sum([len(v) for v in self.dataset_dict.values()])
n_elements = sum(len(v) for v in self.dataset_dict.values())
if batch_size is None:
batch_size = n_elements
indexes_dict = {}

View File

@@ -1,3 +1,8 @@
"""
Module for PinaSubset class
"""
class PinaSubset:
"""
TODO

View File

@@ -6,8 +6,9 @@ from .base_dataset import BaseDataset
class UnsupervisedDataset(BaseDataset):
"""
This class extend BaseDataset class to handle unsupervised dataset,
composed of input points and, optionally, conditional variables
This class extend BaseDataset class to handle
unsupervised dataset,composed of input points
and, optionally, conditional variables
"""
data_type = 'unsupervised'
__slots__ = ['input_points', 'conditional_variables']

View File

@@ -13,6 +13,7 @@ class TorchOptimizer(Optimizer):
self.optimizer_class = optimizer_class
self.kwargs = kwargs
self.optimizer_instance = None
def hook(self, parameters):
self.optimizer_instance = self.optimizer_class(parameters,

View File

@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
eval_size=.1, **kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -39,10 +40,9 @@ class Trainer(pytorch_lightning.Trainer):
self._create_loader()
self._move_to_device()
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self.solver.problem
if hasattr(pb, "unknown_parameters"):
@@ -59,11 +59,13 @@ class Trainer(pytorch_lightning.Trainer):
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
[
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in
self.solver.problem.collector._is_conditions_ready.items()])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
'are sampled. The Trainer got the following:\n'
f'{error_message}')
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem, device=device,
train_size=self.train_size, test_size=self.test_size,
train_size=self.train_size,
test_size=self.test_size,
eval_size=self.eval_size)
data_module.setup()
self._loader = data_module.train_dataloader()