Correct codacy warnings

This commit is contained in:
FilippoOlivo
2024-10-22 14:54:22 +02:00
committed by Nicola Demo
parent 1bc1b3a580
commit 3e30450e9a
10 changed files with 60 additions and 37 deletions

View File

@@ -48,7 +48,8 @@ class Collector:
for condition_name, condition in self.problem.conditions.items(): for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not attribute # if the condition is not ready and domain is not attribute
# of condition, we get and store the data # of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")): if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")):
# get data # get data
keys = condition.__slots__ keys = condition.__slots__
values = [getattr(condition, name) for name in keys] values = [getattr(condition, name) for name in keys]
@@ -69,7 +70,8 @@ class Collector:
already_sampled = [] already_sampled = []
# if we have sampled the condition but not all variables # if we have sampled the condition but not all variables
else: else:
already_sampled = [self.data_collections[loc]['input_points']] already_sampled = [
self.data_collections[loc]['input_points']]
# if the condition is ready but we want to sample again # if the condition is ready but we want to sample again
else: else:
self._is_conditions_ready[loc] = False self._is_conditions_ready[loc] = False
@@ -77,11 +79,13 @@ class Collector:
# get the samples # get the samples
samples = [ samples = [
condition.domain.sample(n=n, mode=mode, variables=variables) condition.domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled ] + already_sampled
pts = merge_tensors(samples) pts = merge_tensors(samples)
if ( if (
set(pts.labels).issubset(sorted(self.problem.input_variables)) set(pts.labels).issubset(
sorted(self.problem.input_variables))
): ):
pts = pts.sort_labels() pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables): if sorted(pts.labels) == sorted(self.problem.input_variables):
@@ -89,7 +93,8 @@ class Collector:
values = [pts, condition.equation] values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values)) self.data_collections[loc] = dict(zip(keys, values))
else: else:
raise RuntimeError('Try to sample variables which are not in problem defined in the problem') raise RuntimeError(
'Try to sample variables which are not in problem defined in the problem')
def add_points(self, new_points_dict): def add_points(self, new_points_dict):
""" """
@@ -100,5 +105,7 @@ class Collector:
""" """
for k, v in new_points_dict.items(): for k, v in new_points_dict.items():
if not self._is_conditions_ready[k]: if not self._is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition') raise RuntimeError(
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) 'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)

View File

@@ -5,6 +5,7 @@ from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface from .data_condition import DataConditionInterface
class Condition: class Condition:
""" """
The class ``Condition`` is used to represent the constraints (physical The class ``Condition`` is used to represent the constraints (physical

View File

@@ -27,7 +27,7 @@ class BaseDataset(Dataset):
if not hasattr(cls, '__slots__'): if not hasattr(cls, '__slots__'):
raise TypeError( raise TypeError(
'Something is wrong, __slots__ must be defined in subclasses.') 'Something is wrong, __slots__ must be defined in subclasses.')
return super(BaseDataset, cls).__new__(cls) return object.__new__(cls)
def __init__(self, problem, device): def __init__(self, problem, device):
"""" """"

View File

@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
:param datasets: list of datasets objects :param datasets: list of datasets objects
""" """
super().__init__() super().__init__()
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset] dataset_classes = [SupervisedDataset, UnsupervisedDataset,
SamplePointDataset]
if datasets is None: if datasets is None:
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes] self.datasets = [DatasetClass(problem, device) for DatasetClass in
dataset_classes]
else: else:
self.datasets = datasets self.datasets = datasets
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
for key, value in dataset.condition_names.items() for key, value in dataset.condition_names.items()
} }
def train_dataloader(self): def train_dataloader(self):
""" """
Return the training dataloader for the dataset Return the training dataloader for the dataset
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
if seed is not None: if seed is not None:
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(seed) generator.manual_seed(seed)
indices = torch.randperm(sum(lengths), generator=generator).tolist() indices = torch.randperm(sum(lengths),
generator=generator).tolist()
else: else:
indices = torch.arange(sum(lengths)).tolist() indices = torch.arange(sum(lengths)).tolist()
else: else:
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist() indices = torch.arange(0, sum(lengths), 1,
dtype=torch.uint8).tolist()
offsets = [ offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths)) sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
] ]

View File

@@ -5,6 +5,9 @@ from .pina_subset import PinaSubset
class Batch: class Batch:
"""
Implementation of the Batch class used during training to perform SGD optimization.
"""
def __init__(self, dataset_dict, idx_dict): def __init__(self, dataset_dict, idx_dict):

View File

@@ -33,7 +33,7 @@ class PinaDataLoader:
Create batches according to the batch_size provided in input. Create batches according to the batch_size provided in input.
""" """
self.batches = [] self.batches = []
n_elements = sum([len(v) for v in self.dataset_dict.values()]) n_elements = sum(len(v) for v in self.dataset_dict.values())
if batch_size is None: if batch_size is None:
batch_size = n_elements batch_size = n_elements
indexes_dict = {} indexes_dict = {}

View File

@@ -1,3 +1,8 @@
"""
Module for PinaSubset class
"""
class PinaSubset: class PinaSubset:
""" """
TODO TODO

View File

@@ -6,8 +6,9 @@ from .base_dataset import BaseDataset
class UnsupervisedDataset(BaseDataset): class UnsupervisedDataset(BaseDataset):
""" """
This class extend BaseDataset class to handle unsupervised dataset, This class extend BaseDataset class to handle
composed of input points and, optionally, conditional variables unsupervised dataset,composed of input points
and, optionally, conditional variables
""" """
data_type = 'unsupervised' data_type = 'unsupervised'
__slots__ = ['input_points', 'conditional_variables'] __slots__ = ['input_points', 'conditional_variables']

View File

@@ -13,6 +13,7 @@ class TorchOptimizer(Optimizer):
self.optimizer_class = optimizer_class self.optimizer_class = optimizer_class
self.kwargs = kwargs self.kwargs = kwargs
self.optimizer_instance = None
def hook(self, parameters): def hook(self, parameters):
self.optimizer_instance = self.optimizer_class(parameters, self.optimizer_instance = self.optimizer_class(parameters,

View File

@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer): class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs): def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
eval_size=.1, **kwargs):
""" """
PINA Trainer class for costumizing every aspect of training via flags. PINA Trainer class for costumizing every aspect of training via flags.
@@ -39,7 +40,6 @@ class Trainer(pytorch_lightning.Trainer):
self._create_loader() self._create_loader()
self._move_to_device() self._move_to_device()
def _move_to_device(self): def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0] device = self._accelerator_connector._parallel_devices[0]
@@ -59,8 +59,10 @@ class Trainer(pytorch_lightning.Trainer):
""" """
if not self.solver.problem.collector.full: if not self.solver.problem.collector.full:
error_message = '\n'.join( error_message = '\n'.join(
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}' [
for key, value in self.solver.problem.collector._is_conditions_ready.items()]) f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in
self.solver.problem.collector._is_conditions_ready.items()])
raise RuntimeError('Cannot create Trainer if not all conditions ' raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n' 'are sampled. The Trainer got the following:\n'
f'{error_message}') f'{error_message}')
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
device = devices[0] device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem, device=device, data_module = PinaDataModule(problem=self.solver.problem, device=device,
train_size=self.train_size, test_size=self.test_size, train_size=self.train_size,
test_size=self.test_size,
eval_size=self.eval_size) eval_size=self.eval_size)
data_module.setup() data_module.setup()
self._loader = data_module.train_dataloader() self._loader = data_module.train_dataloader()