Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
1bc1b3a580
commit
3e30450e9a
@@ -48,7 +48,8 @@ class Collector:
|
||||
for condition_name, condition in self.problem.conditions.items():
|
||||
# if the condition is not ready and domain is not attribute
|
||||
# of condition, we get and store the data
|
||||
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
|
||||
if (not self._is_conditions_ready[condition_name]) and (
|
||||
not hasattr(condition, "domain")):
|
||||
# get data
|
||||
keys = condition.__slots__
|
||||
values = [getattr(condition, name) for name in keys]
|
||||
@@ -69,7 +70,8 @@ class Collector:
|
||||
already_sampled = []
|
||||
# if we have sampled the condition but not all variables
|
||||
else:
|
||||
already_sampled = [self.data_collections[loc]['input_points']]
|
||||
already_sampled = [
|
||||
self.data_collections[loc]['input_points']]
|
||||
# if the condition is ready but we want to sample again
|
||||
else:
|
||||
self._is_conditions_ready[loc] = False
|
||||
@@ -77,11 +79,13 @@ class Collector:
|
||||
|
||||
# get the samples
|
||||
samples = [
|
||||
condition.domain.sample(n=n, mode=mode, variables=variables)
|
||||
condition.domain.sample(n=n, mode=mode,
|
||||
variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
if (
|
||||
set(pts.labels).issubset(sorted(self.problem.input_variables))
|
||||
set(pts.labels).issubset(
|
||||
sorted(self.problem.input_variables))
|
||||
):
|
||||
pts = pts.sort_labels()
|
||||
if sorted(pts.labels) == sorted(self.problem.input_variables):
|
||||
@@ -89,7 +93,8 @@ class Collector:
|
||||
values = [pts, condition.equation]
|
||||
self.data_collections[loc] = dict(zip(keys, values))
|
||||
else:
|
||||
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
|
||||
raise RuntimeError(
|
||||
'Try to sample variables which are not in problem defined in the problem')
|
||||
|
||||
def add_points(self, new_points_dict):
|
||||
"""
|
||||
@@ -100,5 +105,7 @@ class Collector:
|
||||
"""
|
||||
for k, v in new_points_dict.items():
|
||||
if not self._is_conditions_ready[k]:
|
||||
raise RuntimeError('Cannot add points on a non sampled condition')
|
||||
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
|
||||
raise RuntimeError(
|
||||
'Cannot add points on a non sampled condition')
|
||||
self.data_collections[k]['input_points'] = self.data_collections[k][
|
||||
'input_points'].vstack(v)
|
||||
|
||||
@@ -5,6 +5,7 @@ from .input_equation_condition import InputPointsEquationCondition
|
||||
from .input_output_condition import InputOutputPointsCondition
|
||||
from .data_condition import DataConditionInterface
|
||||
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
The class ``Condition`` is used to represent the constraints (physical
|
||||
@@ -38,13 +39,13 @@ class Condition:
|
||||
"""
|
||||
|
||||
__slots__ = list(
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__ +
|
||||
InputPointsEquationCondition.__slots__ +
|
||||
DomainEquationCondition.__slots__ +
|
||||
DataConditionInterface.__slots__
|
||||
)
|
||||
)
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__ +
|
||||
InputPointsEquationCondition.__slots__ +
|
||||
DomainEquationCondition.__slots__ +
|
||||
DataConditionInterface.__slots__
|
||||
)
|
||||
)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class BaseDataset(Dataset):
|
||||
if not hasattr(cls, '__slots__'):
|
||||
raise TypeError(
|
||||
'Something is wrong, __slots__ must be defined in subclasses.')
|
||||
return super(BaseDataset, cls).__new__(cls)
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self, problem, device):
|
||||
""""
|
||||
|
||||
@@ -26,7 +26,7 @@ class PinaDataModule(LightningDataModule):
|
||||
eval_size=.1,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
datasets = None):
|
||||
datasets=None):
|
||||
"""
|
||||
Initialize the object, creating dataset based on input problem
|
||||
:param AbstractProblem problem: PINA problem
|
||||
@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
|
||||
:param datasets: list of datasets objects
|
||||
"""
|
||||
super().__init__()
|
||||
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset]
|
||||
dataset_classes = [SupervisedDataset, UnsupervisedDataset,
|
||||
SamplePointDataset]
|
||||
if datasets is None:
|
||||
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes]
|
||||
self.datasets = [DatasetClass(problem, device) for DatasetClass in
|
||||
dataset_classes]
|
||||
else:
|
||||
self.datasets = datasets
|
||||
|
||||
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
|
||||
for key, value in dataset.condition_names.items()
|
||||
}
|
||||
|
||||
|
||||
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
Return the training dataloader for the dataset
|
||||
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
|
||||
if seed is not None:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
indices = torch.randperm(sum(lengths), generator=generator).tolist()
|
||||
indices = torch.randperm(sum(lengths),
|
||||
generator=generator).tolist()
|
||||
else:
|
||||
indices = torch.arange(sum(lengths)).tolist()
|
||||
else:
|
||||
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
|
||||
indices = torch.arange(0, sum(lengths), 1,
|
||||
dtype=torch.uint8).tolist()
|
||||
offsets = [
|
||||
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
|
||||
]
|
||||
|
||||
@@ -5,6 +5,9 @@ from .pina_subset import PinaSubset
|
||||
|
||||
|
||||
class Batch:
|
||||
"""
|
||||
Implementation of the Batch class used during training to perform SGD optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_dict, idx_dict):
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class PinaDataLoader:
|
||||
Create batches according to the batch_size provided in input.
|
||||
"""
|
||||
self.batches = []
|
||||
n_elements = sum([len(v) for v in self.dataset_dict.values()])
|
||||
n_elements = sum(len(v) for v in self.dataset_dict.values())
|
||||
if batch_size is None:
|
||||
batch_size = n_elements
|
||||
indexes_dict = {}
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
Module for PinaSubset class
|
||||
"""
|
||||
|
||||
|
||||
class PinaSubset:
|
||||
"""
|
||||
TODO
|
||||
|
||||
@@ -6,8 +6,9 @@ from .base_dataset import BaseDataset
|
||||
|
||||
class UnsupervisedDataset(BaseDataset):
|
||||
"""
|
||||
This class extend BaseDataset class to handle unsupervised dataset,
|
||||
composed of input points and, optionally, conditional variables
|
||||
This class extend BaseDataset class to handle
|
||||
unsupervised dataset,composed of input points
|
||||
and, optionally, conditional variables
|
||||
"""
|
||||
data_type = 'unsupervised'
|
||||
__slots__ = ['input_points', 'conditional_variables']
|
||||
|
||||
@@ -13,6 +13,7 @@ class TorchOptimizer(Optimizer):
|
||||
|
||||
self.optimizer_class = optimizer_class
|
||||
self.kwargs = kwargs
|
||||
self.optimizer_instance = None
|
||||
|
||||
def hook(self, parameters):
|
||||
self.optimizer_instance = self.optimizer_class(parameters,
|
||||
|
||||
@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
|
||||
eval_size=.1, **kwargs):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
|
||||
@@ -39,7 +40,6 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
|
||||
def _move_to_device(self):
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
|
||||
@@ -59,11 +59,13 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
"""
|
||||
if not self.solver.problem.collector.full:
|
||||
error_message = '\n'.join(
|
||||
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
|
||||
[
|
||||
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||
for key, value in
|
||||
self.solver.problem.collector._is_conditions_ready.items()])
|
||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
devices = self._accelerator_connector._parallel_devices
|
||||
|
||||
if len(devices) > 1:
|
||||
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
device = devices[0]
|
||||
|
||||
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
||||
train_size=self.train_size, test_size=self.test_size,
|
||||
train_size=self.train_size,
|
||||
test_size=self.test_size,
|
||||
eval_size=self.eval_size)
|
||||
data_module.setup()
|
||||
self._loader = data_module.train_dataloader()
|
||||
|
||||
Reference in New Issue
Block a user