diff --git a/pina/collector.py b/pina/collector.py index f9ef194..4ebf236 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -48,7 +48,8 @@ class Collector: for condition_name, condition in self.problem.conditions.items(): # if the condition is not ready and domain is not attribute # of condition, we get and store the data - if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")): + if (not self._is_conditions_ready[condition_name]) and ( + not hasattr(condition, "domain")): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] @@ -69,7 +70,8 @@ class Collector: already_sampled = [] # if we have sampled the condition but not all variables else: - already_sampled = [self.data_collections[loc]['input_points']] + already_sampled = [ + self.data_collections[loc]['input_points']] # if the condition is ready but we want to sample again else: self._is_conditions_ready[loc] = False @@ -77,11 +79,13 @@ class Collector: # get the samples samples = [ - condition.domain.sample(n=n, mode=mode, variables=variables) + condition.domain.sample(n=n, mode=mode, + variables=variables) ] + already_sampled pts = merge_tensors(samples) if ( - set(pts.labels).issubset(sorted(self.problem.input_variables)) + set(pts.labels).issubset( + sorted(self.problem.input_variables)) ): pts = pts.sort_labels() if sorted(pts.labels) == sorted(self.problem.input_variables): @@ -89,7 +93,8 @@ class Collector: values = [pts, condition.equation] self.data_collections[loc] = dict(zip(keys, values)) else: - raise RuntimeError('Try to sample variables which are not in problem defined in the problem') + raise RuntimeError( + 'Try to sample variables which are not in problem defined in the problem') def add_points(self, new_points_dict): """ @@ -100,5 +105,7 @@ class Collector: """ for k, v in new_points_dict.items(): if not self._is_conditions_ready[k]: - raise RuntimeError('Cannot add points on a non sampled condition') - self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) + raise RuntimeError( + 'Cannot add points on a non sampled condition') + self.data_collections[k]['input_points'] = self.data_collections[k][ + 'input_points'].vstack(v) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 09180cc..01965fe 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -5,6 +5,7 @@ from .input_equation_condition import InputPointsEquationCondition from .input_output_condition import InputOutputPointsCondition from .data_condition import DataConditionInterface + class Condition: """ The class ``Condition`` is used to represent the constraints (physical @@ -38,23 +39,23 @@ class Condition: """ __slots__ = list( - set( - InputOutputPointsCondition.__slots__ + - InputPointsEquationCondition.__slots__ + - DomainEquationCondition.__slots__ + - DataConditionInterface.__slots__ - ) - ) + set( + InputOutputPointsCondition.__slots__ + + InputPointsEquationCondition.__slots__ + + DomainEquationCondition.__slots__ + + DataConditionInterface.__slots__ + ) + ) def __new__(cls, *args, **kwargs): - + if len(args) != 0: raise ValueError( "Condition takes only the following keyword " f"arguments: {Condition.__slots__}." ) - - sorted_keys = sorted(kwargs.keys()) + + sorted_keys = sorted(kwargs.keys()) if sorted_keys == sorted(InputOutputPointsCondition.__slots__): return InputOutputPointsCondition(**kwargs) elif sorted_keys == sorted(InputPointsEquationCondition.__slots__): @@ -66,4 +67,4 @@ class Condition: elif sorted_keys == DataConditionInterface.__slots__[0]: return DataConditionInterface(**kwargs) else: - raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") \ No newline at end of file + raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py index f1d17ae..b15a0be 100644 --- a/pina/data/base_dataset.py +++ b/pina/data/base_dataset.py @@ -27,7 +27,7 @@ class BaseDataset(Dataset): if not hasattr(cls, '__slots__'): raise TypeError( 'Something is wrong, __slots__ must be defined in subclasses.') - return super(BaseDataset, cls).__new__(cls) + return object.__new__(cls) def __init__(self, problem, device): """" diff --git a/pina/data/data_module.py b/pina/data/data_module.py index e4e8a45..25c7e54 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -26,7 +26,7 @@ class PinaDataModule(LightningDataModule): eval_size=.1, batch_size=None, shuffle=True, - datasets = None): + datasets=None): """ Initialize the object, creating dataset based on input problem :param AbstractProblem problem: PINA problem @@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule): :param datasets: list of datasets objects """ super().__init__() - dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset] + dataset_classes = [SupervisedDataset, UnsupervisedDataset, + SamplePointDataset] if datasets is None: - self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes] + self.datasets = [DatasetClass(problem, device) for DatasetClass in + dataset_classes] else: self.datasets = datasets @@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule): for key, value in dataset.condition_names.items() } - - def train_dataloader(self): """ Return the training dataloader for the dataset @@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule): if seed is not None: generator = torch.Generator() generator.manual_seed(seed) - indices = torch.randperm(sum(lengths), generator=generator).tolist() + indices = torch.randperm(sum(lengths), + generator=generator).tolist() else: indices = torch.arange(sum(lengths)).tolist() else: - indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist() + indices = torch.arange(0, sum(lengths), 1, + dtype=torch.uint8).tolist() offsets = [ sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths)) ] diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index f61e002..ed34a91 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -5,6 +5,9 @@ from .pina_subset import PinaSubset class Batch: + """ + Implementation of the Batch class used during training to perform SGD optimization. + """ def __init__(self, dataset_dict, idx_dict): diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py index cbd8fe8..e2d3fb7 100644 --- a/pina/data/pina_dataloader.py +++ b/pina/data/pina_dataloader.py @@ -33,7 +33,7 @@ class PinaDataLoader: Create batches according to the batch_size provided in input. """ self.batches = [] - n_elements = sum([len(v) for v in self.dataset_dict.values()]) + n_elements = sum(len(v) for v in self.dataset_dict.values()) if batch_size is None: batch_size = n_elements indexes_dict = {} diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py index 41571f9..844321b 100644 --- a/pina/data/pina_subset.py +++ b/pina/data/pina_subset.py @@ -1,3 +1,8 @@ +""" +Module for PinaSubset class +""" + + class PinaSubset: """ TODO diff --git a/pina/data/unsupervised_dataset.py b/pina/data/unsupervised_dataset.py index f4e8fb3..18cf296 100644 --- a/pina/data/unsupervised_dataset.py +++ b/pina/data/unsupervised_dataset.py @@ -6,8 +6,9 @@ from .base_dataset import BaseDataset class UnsupervisedDataset(BaseDataset): """ - This class extend BaseDataset class to handle unsupervised dataset, - composed of input points and, optionally, conditional variables + This class extend BaseDataset class to handle + unsupervised dataset,composed of input points + and, optionally, conditional variables """ data_type = 'unsupervised' __slots__ = ['input_points', 'conditional_variables'] diff --git a/pina/optim/torch_optimizer.py b/pina/optim/torch_optimizer.py index ed90846..54818d5 100644 --- a/pina/optim/torch_optimizer.py +++ b/pina/optim/torch_optimizer.py @@ -13,6 +13,7 @@ class TorchOptimizer(Optimizer): self.optimizer_class = optimizer_class self.kwargs = kwargs + self.optimizer_instance = None def hook(self, parameters): self.optimizer_instance = self.optimizer_class(parameters, diff --git a/pina/trainer.py b/pina/trainer.py index 49c6a40..884eef7 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -9,7 +9,8 @@ from .solvers.solver import SolverInterface class Trainer(pytorch_lightning.Trainer): - def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs): + def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, + eval_size=.1, **kwargs): """ PINA Trainer class for costumizing every aspect of training via flags. @@ -39,10 +40,9 @@ class Trainer(pytorch_lightning.Trainer): self._create_loader() self._move_to_device() - def _move_to_device(self): device = self._accelerator_connector._parallel_devices[0] - + # move parameters to device pb = self.solver.problem if hasattr(pb, "unknown_parameters"): @@ -59,11 +59,13 @@ class Trainer(pytorch_lightning.Trainer): """ if not self.solver.problem.collector.full: error_message = '\n'.join( - [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}' - for key, value in self.solver.problem.collector._is_conditions_ready.items()]) + [ + f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}' + for key, value in + self.solver.problem.collector._is_conditions_ready.items()]) raise RuntimeError('Cannot create Trainer if not all conditions ' - 'are sampled. The Trainer got the following:\n' - f'{error_message}') + 'are sampled. The Trainer got the following:\n' + f'{error_message}') devices = self._accelerator_connector._parallel_devices if len(devices) > 1: @@ -72,7 +74,8 @@ class Trainer(pytorch_lightning.Trainer): device = devices[0] data_module = PinaDataModule(problem=self.solver.problem, device=device, - train_size=self.train_size, test_size=self.test_size, + train_size=self.train_size, + test_size=self.test_size, eval_size=self.eval_size) data_module.setup() self._loader = data_module.train_dataloader()