Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
1bc1b3a580
commit
3e30450e9a
@@ -48,7 +48,8 @@ class Collector:
|
|||||||
for condition_name, condition in self.problem.conditions.items():
|
for condition_name, condition in self.problem.conditions.items():
|
||||||
# if the condition is not ready and domain is not attribute
|
# if the condition is not ready and domain is not attribute
|
||||||
# of condition, we get and store the data
|
# of condition, we get and store the data
|
||||||
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
|
if (not self._is_conditions_ready[condition_name]) and (
|
||||||
|
not hasattr(condition, "domain")):
|
||||||
# get data
|
# get data
|
||||||
keys = condition.__slots__
|
keys = condition.__slots__
|
||||||
values = [getattr(condition, name) for name in keys]
|
values = [getattr(condition, name) for name in keys]
|
||||||
@@ -69,7 +70,8 @@ class Collector:
|
|||||||
already_sampled = []
|
already_sampled = []
|
||||||
# if we have sampled the condition but not all variables
|
# if we have sampled the condition but not all variables
|
||||||
else:
|
else:
|
||||||
already_sampled = [self.data_collections[loc]['input_points']]
|
already_sampled = [
|
||||||
|
self.data_collections[loc]['input_points']]
|
||||||
# if the condition is ready but we want to sample again
|
# if the condition is ready but we want to sample again
|
||||||
else:
|
else:
|
||||||
self._is_conditions_ready[loc] = False
|
self._is_conditions_ready[loc] = False
|
||||||
@@ -77,11 +79,13 @@ class Collector:
|
|||||||
|
|
||||||
# get the samples
|
# get the samples
|
||||||
samples = [
|
samples = [
|
||||||
condition.domain.sample(n=n, mode=mode, variables=variables)
|
condition.domain.sample(n=n, mode=mode,
|
||||||
|
variables=variables)
|
||||||
] + already_sampled
|
] + already_sampled
|
||||||
pts = merge_tensors(samples)
|
pts = merge_tensors(samples)
|
||||||
if (
|
if (
|
||||||
set(pts.labels).issubset(sorted(self.problem.input_variables))
|
set(pts.labels).issubset(
|
||||||
|
sorted(self.problem.input_variables))
|
||||||
):
|
):
|
||||||
pts = pts.sort_labels()
|
pts = pts.sort_labels()
|
||||||
if sorted(pts.labels) == sorted(self.problem.input_variables):
|
if sorted(pts.labels) == sorted(self.problem.input_variables):
|
||||||
@@ -89,7 +93,8 @@ class Collector:
|
|||||||
values = [pts, condition.equation]
|
values = [pts, condition.equation]
|
||||||
self.data_collections[loc] = dict(zip(keys, values))
|
self.data_collections[loc] = dict(zip(keys, values))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
|
raise RuntimeError(
|
||||||
|
'Try to sample variables which are not in problem defined in the problem')
|
||||||
|
|
||||||
def add_points(self, new_points_dict):
|
def add_points(self, new_points_dict):
|
||||||
"""
|
"""
|
||||||
@@ -100,5 +105,7 @@ class Collector:
|
|||||||
"""
|
"""
|
||||||
for k, v in new_points_dict.items():
|
for k, v in new_points_dict.items():
|
||||||
if not self._is_conditions_ready[k]:
|
if not self._is_conditions_ready[k]:
|
||||||
raise RuntimeError('Cannot add points on a non sampled condition')
|
raise RuntimeError(
|
||||||
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
|
'Cannot add points on a non sampled condition')
|
||||||
|
self.data_collections[k]['input_points'] = self.data_collections[k][
|
||||||
|
'input_points'].vstack(v)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from .input_equation_condition import InputPointsEquationCondition
|
|||||||
from .input_output_condition import InputOutputPointsCondition
|
from .input_output_condition import InputOutputPointsCondition
|
||||||
from .data_condition import DataConditionInterface
|
from .data_condition import DataConditionInterface
|
||||||
|
|
||||||
|
|
||||||
class Condition:
|
class Condition:
|
||||||
"""
|
"""
|
||||||
The class ``Condition`` is used to represent the constraints (physical
|
The class ``Condition`` is used to represent the constraints (physical
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class BaseDataset(Dataset):
|
|||||||
if not hasattr(cls, '__slots__'):
|
if not hasattr(cls, '__slots__'):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'Something is wrong, __slots__ must be defined in subclasses.')
|
'Something is wrong, __slots__ must be defined in subclasses.')
|
||||||
return super(BaseDataset, cls).__new__(cls)
|
return object.__new__(cls)
|
||||||
|
|
||||||
def __init__(self, problem, device):
|
def __init__(self, problem, device):
|
||||||
""""
|
""""
|
||||||
|
|||||||
@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:param datasets: list of datasets objects
|
:param datasets: list of datasets objects
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset]
|
dataset_classes = [SupervisedDataset, UnsupervisedDataset,
|
||||||
|
SamplePointDataset]
|
||||||
if datasets is None:
|
if datasets is None:
|
||||||
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes]
|
self.datasets = [DatasetClass(problem, device) for DatasetClass in
|
||||||
|
dataset_classes]
|
||||||
else:
|
else:
|
||||||
self.datasets = datasets
|
self.datasets = datasets
|
||||||
|
|
||||||
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
|
|||||||
for key, value in dataset.condition_names.items()
|
for key, value in dataset.condition_names.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
"""
|
"""
|
||||||
Return the training dataloader for the dataset
|
Return the training dataloader for the dataset
|
||||||
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
indices = torch.randperm(sum(lengths), generator=generator).tolist()
|
indices = torch.randperm(sum(lengths),
|
||||||
|
generator=generator).tolist()
|
||||||
else:
|
else:
|
||||||
indices = torch.arange(sum(lengths)).tolist()
|
indices = torch.arange(sum(lengths)).tolist()
|
||||||
else:
|
else:
|
||||||
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
|
indices = torch.arange(0, sum(lengths), 1,
|
||||||
|
dtype=torch.uint8).tolist()
|
||||||
offsets = [
|
offsets = [
|
||||||
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
|
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ from .pina_subset import PinaSubset
|
|||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Batch:
|
||||||
|
"""
|
||||||
|
Implementation of the Batch class used during training to perform SGD optimization.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_dict, idx_dict):
|
def __init__(self, dataset_dict, idx_dict):
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class PinaDataLoader:
|
|||||||
Create batches according to the batch_size provided in input.
|
Create batches according to the batch_size provided in input.
|
||||||
"""
|
"""
|
||||||
self.batches = []
|
self.batches = []
|
||||||
n_elements = sum([len(v) for v in self.dataset_dict.values()])
|
n_elements = sum(len(v) for v in self.dataset_dict.values())
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
batch_size = n_elements
|
batch_size = n_elements
|
||||||
indexes_dict = {}
|
indexes_dict = {}
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Module for PinaSubset class
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class PinaSubset:
|
class PinaSubset:
|
||||||
"""
|
"""
|
||||||
TODO
|
TODO
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ from .base_dataset import BaseDataset
|
|||||||
|
|
||||||
class UnsupervisedDataset(BaseDataset):
|
class UnsupervisedDataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
This class extend BaseDataset class to handle unsupervised dataset,
|
This class extend BaseDataset class to handle
|
||||||
composed of input points and, optionally, conditional variables
|
unsupervised dataset,composed of input points
|
||||||
|
and, optionally, conditional variables
|
||||||
"""
|
"""
|
||||||
data_type = 'unsupervised'
|
data_type = 'unsupervised'
|
||||||
__slots__ = ['input_points', 'conditional_variables']
|
__slots__ = ['input_points', 'conditional_variables']
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ class TorchOptimizer(Optimizer):
|
|||||||
|
|
||||||
self.optimizer_class = optimizer_class
|
self.optimizer_class = optimizer_class
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.optimizer_instance = None
|
||||||
|
|
||||||
def hook(self, parameters):
|
def hook(self, parameters):
|
||||||
self.optimizer_instance = self.optimizer_class(parameters,
|
self.optimizer_instance = self.optimizer_class(parameters,
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface
|
|||||||
|
|
||||||
class Trainer(pytorch_lightning.Trainer):
|
class Trainer(pytorch_lightning.Trainer):
|
||||||
|
|
||||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
|
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
|
||||||
|
eval_size=.1, **kwargs):
|
||||||
"""
|
"""
|
||||||
PINA Trainer class for costumizing every aspect of training via flags.
|
PINA Trainer class for costumizing every aspect of training via flags.
|
||||||
|
|
||||||
@@ -39,7 +40,6 @@ class Trainer(pytorch_lightning.Trainer):
|
|||||||
self._create_loader()
|
self._create_loader()
|
||||||
self._move_to_device()
|
self._move_to_device()
|
||||||
|
|
||||||
|
|
||||||
def _move_to_device(self):
|
def _move_to_device(self):
|
||||||
device = self._accelerator_connector._parallel_devices[0]
|
device = self._accelerator_connector._parallel_devices[0]
|
||||||
|
|
||||||
@@ -59,8 +59,10 @@ class Trainer(pytorch_lightning.Trainer):
|
|||||||
"""
|
"""
|
||||||
if not self.solver.problem.collector.full:
|
if not self.solver.problem.collector.full:
|
||||||
error_message = '\n'.join(
|
error_message = '\n'.join(
|
||||||
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
[
|
||||||
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
|
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||||
|
for key, value in
|
||||||
|
self.solver.problem.collector._is_conditions_ready.items()])
|
||||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||||
'are sampled. The Trainer got the following:\n'
|
'are sampled. The Trainer got the following:\n'
|
||||||
f'{error_message}')
|
f'{error_message}')
|
||||||
@@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer):
|
|||||||
device = devices[0]
|
device = devices[0]
|
||||||
|
|
||||||
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
||||||
train_size=self.train_size, test_size=self.test_size,
|
train_size=self.train_size,
|
||||||
|
test_size=self.test_size,
|
||||||
eval_size=self.eval_size)
|
eval_size=self.eval_size)
|
||||||
data_module.setup()
|
data_module.setup()
|
||||||
self._loader = data_module.train_dataloader()
|
self._loader = data_module.train_dataloader()
|
||||||
|
|||||||
Reference in New Issue
Block a user