From 42ab1a666b5b6665a775c2d64edd31df473165a9 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Mon, 24 Feb 2025 11:26:49 +0100 Subject: [PATCH] Formatting * Adding black as dev dependency * Formatting pina code * Formatting tests --- pina/__init__.py | 8 +- pina/adaptive_function/adaptive_function.py | 2 +- .../adaptive_function_interface.py | 2 +- pina/adaptive_functions/__init__.py | 7 +- pina/callback/adaptive_refinement_callback.py | 2 +- pina/callback/processing_callback.py | 18 +- pina/callbacks/__init__.py | 7 +- pina/collector.py | 22 +- pina/condition/__init__.py | 12 +- pina/condition/condition.py | 31 ++- pina/condition/condition_interface.py | 7 +- pina/condition/data_condition.py | 4 +- pina/condition/domain_equation_condition.py | 6 +- pina/condition/input_equation_condition.py | 6 +- pina/condition/input_output_condition.py | 9 +- pina/data/__init__.py | 7 +- pina/data/data_module.py | 215 +++++++++++------- pina/data/dataset.py | 105 +++++---- pina/domain/cartesian.py | 10 +- pina/domain/domain_interface.py | 8 +- pina/domain/exclusion_domain.py | 2 +- pina/domain/intersection_domain.py | 2 +- pina/domain/operation_interface.py | 2 +- pina/domain/simplex.py | 2 +- pina/domain/union_domain.py | 2 +- pina/equation/equation.py | 2 +- pina/equation/equation_factory.py | 2 +- pina/equation/equation_interface.py | 2 +- pina/equation/system_equation.py | 2 +- pina/geometry/__init__.py | 9 +- pina/graph.py | 181 ++++++++------- pina/label_tensor.py | 172 ++++++++------ pina/loss/__init__.py | 10 +- pina/loss/loss_interface.py | 4 +- pina/loss/lp_loss.py | 5 +- pina/loss/power_loss.py | 4 +- pina/loss/scalar_weighting.py | 8 +- pina/loss/weighting_interface.py | 2 +- pina/model/__init__.py | 4 +- pina/model/block/__init__.py | 2 +- .../block/average_neural_operator_block.py | 2 +- pina/model/block/embedding.py | 2 +- pina/model/block/gno_block.py | 39 ++-- pina/model/block/low_rank_block.py | 2 +- pina/model/graph_neural_operator.py | 71 +++--- pina/model/layers/__init__.py | 7 +- pina/operator.py | 19 +- pina/operators.py | 7 +- pina/optim/__init__.py | 2 +- pina/optim/optimizer_interface.py | 4 +- pina/optim/scheduler_interface.py | 6 +- pina/optim/torch_optimizer.py | 8 +- pina/optim/torch_scheduler.py | 9 +- pina/problem/abstract_problem.py | 27 ++- pina/problem/inverse_problem.py | 8 +- pina/problem/zoo/__init__.py | 10 +- pina/problem/zoo/diffusion_reaction.py | 46 ++-- .../problem/zoo/inverse_diffusion_reaction.py | 66 +++--- pina/problem/zoo/inverse_poisson_2d_square.py | 52 +++-- pina/problem/zoo/poisson_2d_square.py | 44 ++-- pina/problem/zoo/supervised_problem.py | 8 +- pina/solver/garom.py | 71 +++--- .../physic_informed_solver/causal_pinn.py | 34 +-- .../competitive_pinn.py | 63 ++--- .../physic_informed_solver/gradient_pinn.py | 34 +-- pina/solver/physic_informed_solver/pinn.py | 37 +-- .../physic_informed_solver/pinn_interface.py | 48 ++-- .../solver/physic_informed_solver/rba_pinn.py | 51 +++-- .../self_adaptive_pinn.py | 94 ++++---- pina/solver/reduced_order_model.py | 6 +- pina/solver/solver.py | 94 ++++---- pina/solver/supervised.py | 52 +++-- pina/solvers/__init__.py | 7 +- pina/solvers/pinns/__init__.py | 9 +- pina/trainer.py | 131 ++++++----- pina/utils.py | 14 +- pyproject.toml | 3 + 77 files changed, 1170 insertions(+), 924 deletions(-) diff --git a/pina/__init__.py b/pina/__init__.py index 06b9482..f6e7359 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,11 +1,11 @@ __all__ = [ "Trainer", - "LabelTensor", + "LabelTensor", "Condition", 
"PinaDataModule", - 'Graph', + "Graph", "SolverInterface", - "MultiSolverInterface" + "MultiSolverInterface", ] from .label_tensor import LabelTensor @@ -13,4 +13,4 @@ from .graph import Graph from .solver import SolverInterface, MultiSolverInterface from .trainer import Trainer from .condition.condition import Condition -from .data import PinaDataModule \ No newline at end of file +from .data import PinaDataModule diff --git a/pina/adaptive_function/adaptive_function.py b/pina/adaptive_function/adaptive_function.py index 36e2a95..9bf5fba 100644 --- a/pina/adaptive_function/adaptive_function.py +++ b/pina/adaptive_function/adaptive_function.py @@ -1,4 +1,4 @@ -""" Module for adaptive functions. """ +"""Module for adaptive functions.""" import torch from ..utils import check_consistency diff --git a/pina/adaptive_function/adaptive_function_interface.py b/pina/adaptive_function/adaptive_function_interface.py index 7058faf..20fae51 100644 --- a/pina/adaptive_function/adaptive_function_interface.py +++ b/pina/adaptive_function/adaptive_function_interface.py @@ -1,4 +1,4 @@ -""" Module for adaptive functions. """ +"""Module for adaptive functions.""" import torch diff --git a/pina/adaptive_functions/__init__.py b/pina/adaptive_functions/__init__.py index 6381b47..9af99d2 100644 --- a/pina/adaptive_functions/__init__.py +++ b/pina/adaptive_functions/__init__.py @@ -8,6 +8,7 @@ from ..utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - f"'pina.adaptive_functions' is deprecated and will be removed " - f"in future versions. Please use 'pina.adaptive_function' instead.", - DeprecationWarning) \ No newline at end of file + f"'pina.adaptive_functions' is deprecated and will be removed " + f"in future versions. 
Please use 'pina.adaptive_function' instead.", + DeprecationWarning, +) diff --git a/pina/callback/adaptive_refinement_callback.py b/pina/callback/adaptive_refinement_callback.py index 3462fd9..951ee75 100644 --- a/pina/callback/adaptive_refinement_callback.py +++ b/pina/callback/adaptive_refinement_callback.py @@ -67,7 +67,7 @@ class R3Refinement(Callback): # compute residual res_loss = {} tot_loss = [] - for location in self._sampling_locations: #TODO fix for new collector + for location in self._sampling_locations: # TODO fix for new collector condition = solver.problem.conditions[location] pts = solver.problem.input_pts[location] # send points to correct device diff --git a/pina/callback/processing_callback.py b/pina/callback/processing_callback.py index e8d0f68..f3a13c1 100644 --- a/pina/callback/processing_callback.py +++ b/pina/callback/processing_callback.py @@ -26,7 +26,7 @@ class MetricTracker(Callback): super().__init__() self._collection = [] # Default to tracking 'train_loss' and 'val_loss' if not specified - self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss'] + self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"] def on_train_epoch_end(self, trainer, pl_module): """ @@ -40,7 +40,8 @@ class MetricTracker(Callback): if trainer.current_epoch > 0: # Append only the tracked metrics to avoid unnecessary data tracked_metrics = { - k: v for k, v in trainer.logged_metrics.items() + k: v + for k, v in trainer.logged_metrics.items() if k in self.metrics_to_track } self._collection.append(copy.deepcopy(tracked_metrics)) @@ -57,16 +58,18 @@ class MetricTracker(Callback): return {} # Get intersection of keys across all collected dictionaries - common_keys = set(self._collection[0]).intersection(*self._collection[1:]) - + common_keys = set(self._collection[0]).intersection( + *self._collection[1:] + ) + # Stack the metric values for common keys and return return { k: torch.stack([dic[k] for dic in self._collection]) - for k in common_keys if k in self.metrics_to_track + for k in common_keys + if k in self.metrics_to_track } - class PINAProgressBar(TQDMProgressBar): BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" @@ -142,7 +145,8 @@ class PINAProgressBar(TQDMProgressBar): for key in self._sorted_metrics: if ( key not in trainer.solver.problem.conditions.keys() - and key != "train" and key != "val" + and key != "train" + and key != "val" ): raise KeyError(f"Key '{key}' is not present in the dictionary") # add the loss pedix diff --git a/pina/callbacks/__init__.py b/pina/callbacks/__init__.py index e2dc42b..8c2c71d 100644 --- a/pina/callbacks/__init__.py +++ b/pina/callbacks/__init__.py @@ -8,6 +8,7 @@ from ..utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - f"'pina.callbacks' is deprecated and will be removed " - f"in future versions. Please use 'pina.callback' instead.", - DeprecationWarning) \ No newline at end of file + f"'pina.callbacks' is deprecated and will be removed " + f"in future versions. 
Please use 'pina.callback' instead.", + DeprecationWarning, +) diff --git a/pina/collector.py b/pina/collector.py index 93ea182..c8e8160 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -1,6 +1,7 @@ """ # TODO """ + from .graph import Graph from .utils import check_consistency @@ -14,14 +15,12 @@ class Collector: # those variables are used for the dataloading self._data_collections = {name: {} for name in self.problem.conditions} self.conditions_name = { - i: name - for i, name in enumerate(self.problem.conditions) + i: name for i, name in enumerate(self.problem.conditions) } # variables used to check that all conditions are sampled self._is_conditions_ready = { - name: False - for name in self.problem.conditions + name: False for name in self.problem.conditions } self.full = False @@ -51,13 +50,16 @@ class Collector: for condition_name, condition in self.problem.conditions.items(): # if the condition is not ready and domain is not attribute # of condition, we get and store the data - if (not self._is_conditions_ready[condition_name]) and (not hasattr( - condition, "domain")): + if (not self._is_conditions_ready[condition_name]) and ( + not hasattr(condition, "domain") + ): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] - values = [value.data if isinstance( - value, Graph) else value for value in values] + values = [ + value.data if isinstance(value, Graph) else value + for value in values + ] self.data_collections[condition_name] = dict(zip(keys, values)) # condition now is ready self._is_conditions_ready[condition_name] = True @@ -74,6 +76,6 @@ class Collector: samples = self.problem.discretised_domains[condition.domain] self.data_collections[condition_name] = { - 'input_points': samples, - 'equation': condition.equation + "input_points": samples, + "equation": condition.equation, } diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4c89b75..3893e34 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -1,12 +1,12 @@ __all__ = [ - 'Condition', - 'ConditionInterface', - 'DomainEquationCondition', - 'InputPointsEquationCondition', - 'InputOutputPointsCondition', + "Condition", + "ConditionInterface", + "DomainEquationCondition", + "InputPointsEquationCondition", + "InputOutputPointsCondition", ] from .condition_interface import ConditionInterface from .domain_equation_condition import DomainEquationCondition from .input_equation_condition import InputPointsEquationCondition -from .input_output_condition import InputOutputPointsCondition \ No newline at end of file +from .input_output_condition import InputOutputPointsCondition diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 9ff27c1..e01db1f 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,4 +1,4 @@ -""" Condition module. 
""" +"""Condition module.""" from .domain_equation_condition import DomainEquationCondition from .input_equation_condition import InputPointsEquationCondition @@ -11,6 +11,7 @@ from ..utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) + class Condition: """ The class ``Condition`` is used to represent the constraints (physical @@ -44,24 +45,30 @@ class Condition: """ __slots__ = list( - set(InputOutputPointsCondition.__slots__ + - InputPointsEquationCondition.__slots__ + - DomainEquationCondition.__slots__ + - DataConditionInterface.__slots__)) + set( + InputOutputPointsCondition.__slots__ + + InputPointsEquationCondition.__slots__ + + DomainEquationCondition.__slots__ + + DataConditionInterface.__slots__ + ) + ) def __new__(cls, *args, **kwargs): if len(args) != 0: - raise ValueError("Condition takes only the following keyword " - f"arguments: {Condition.__slots__}.") + raise ValueError( + "Condition takes only the following keyword " + f"arguments: {Condition.__slots__}." + ) # back-compatibility 0.1 - if 'location' in kwargs.keys(): - kwargs['domain'] = kwargs.pop('location') + if "location" in kwargs.keys(): + kwargs["domain"] = kwargs.pop("location") warnings.warn( - f"'location' is deprecated and will be removed " - f"in future versions. Please use 'domain' instead.", - DeprecationWarning) + f"'location' is deprecated and will be removed " + f"in future versions. Please use 'domain' instead.", + DeprecationWarning, + ) sorted_keys = sorted(kwargs.keys()) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index f2fe5db..a9d62fd 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -3,7 +3,7 @@ from abc import ABCMeta class ConditionInterface(metaclass=ABCMeta): - condition_types = ['physics', 'supervised', 'unsupervised'] + condition_types = ["physics", "supervised", "unsupervised"] def __init__(self, *args, **kwargs): self._condition_type = None @@ -28,6 +28,7 @@ class ConditionInterface(metaclass=ABCMeta): for value in values: if value not in ConditionInterface.condition_types: raise ValueError( - 'Unavailable type of condition, expected one of' - f' {ConditionInterface.condition_types}.') + "Unavailable type of condition, expected one of" + f" {ConditionInterface.condition_types}." 
+ ) self._condition_type = values diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 255c329..ffd10f3 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -25,8 +25,8 @@ class DataConditionInterface(ConditionInterface): self.conditional_variables = conditional_variables def __setattr__(self, key, value): - if (key == 'input_points') or (key == 'conditional_variables'): + if (key == "input_points") or (key == "conditional_variables"): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) DataConditionInterface.__dict__[key].__set__(self, value) - elif key in ('_problem', '_condition_type'): + elif key in ("_problem", "_condition_type"): super().__setattr__(key, value) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 9fb0dcb..002a7c4 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -23,11 +23,11 @@ class DomainEquationCondition(ConditionInterface): self.equation = equation def __setattr__(self, key, value): - if key == 'domain': + if key == "domain": check_consistency(value, (DomainInterface, str)) DomainEquationCondition.__dict__[key].__set__(self, value) - elif key == 'equation': + elif key == "equation": check_consistency(value, (EquationInterface)) DomainEquationCondition.__dict__[key].__set__(self, value) - elif key in ('_problem', '_condition_type'): + elif key in ("_problem", "_condition_type"): super().__setattr__(key, value) diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 7416cff..061261f 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -24,13 +24,13 @@ class InputPointsEquationCondition(ConditionInterface): self.equation = equation def __setattr__(self, key, value): - if key == 'input_points': + if key == "input_points": check_consistency( value, (LabelTensor) ) # for now only labeltensors, we need labels for the operator! 
InputPointsEquationCondition.__dict__[key].__set__(self, value) - elif key == 'equation': + elif key == "equation": check_consistency(value, (EquationInterface)) InputPointsEquationCondition.__dict__[key].__set__(self, value) - elif key in ('_problem', '_condition_type'): + elif key in ("_problem", "_condition_type"): super().__setattr__(key, value) diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index 5cf5518..47f182a 100644 --- a/pina/condition/input_output_condition.py +++ b/pina/condition/input_output_condition.py @@ -24,8 +24,11 @@ class InputOutputPointsCondition(ConditionInterface): self.output_points = output_points def __setattr__(self, key, value): - if (key == 'input_points') or (key == 'output_points'): - check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data)) + if (key == "input_points") or (key == "output_points"): + check_consistency( + value, + (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data), + ) InputOutputPointsCondition.__dict__[key].__set__(self, value) - elif key in ('_problem', '_condition_type'): + elif key in ("_problem", "_condition_type"): super().__setattr__(key, value) diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 292c9ed..4c14188 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -1,12 +1,9 @@ """ Import data classes """ -__all__ = [ - 'PinaDataModule', - 'PinaDataset' -] +__all__ = ["PinaDataModule", "PinaDataset"] from .data_module import PinaDataModule -from .dataset import PinaDataset +from .dataset import PinaDataset diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 9ecfaa5..ef6a6de 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -11,7 +11,7 @@ from ..collector import Collector class DummyDataloader: - """" + """ " Dummy dataloader used when batch size is None. It callects all the data in self.dataset and returns it when it is called a single batch. """ @@ -28,14 +28,17 @@ class DummyDataloader: - **Non-Distributed Environment**: - Fetches the entire dataset. """ - if (torch.distributed.is_available() and - torch.distributed.is_initialized()): + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() if len(dataset) < world_size: raise RuntimeError( "Dimension of the dataset smaller than world size." 
- " Increase the size of the partition or use a single GPU") + " Increase the size of the partition or use a single GPU" + ) idx, i = [], rank while i < len(dataset): idx.append(i) @@ -57,9 +60,11 @@ class DummyDataloader: class Collator: def __init__(self, max_conditions_lengths, dataset=None): self.max_conditions_lengths = max_conditions_lengths - self.callable_function = self._collate_custom_dataloader if \ - max_conditions_lengths is None else ( - self._collate_standard_dataloader) + self.callable_function = ( + self._collate_custom_dataloader + if max_conditions_lengths is None + else (self._collate_standard_dataloader) + ) self.dataset = dataset if isinstance(self.dataset, PinaTensorDataset): self._collate = self._collate_tensor_dataset @@ -82,9 +87,15 @@ class Collator: single_cond_dict = {} condition_args = batch[0][condition_name].keys() for arg in condition_args: - data_list = [batch[idx][condition_name][arg] for idx in range( - min(len(batch), - self.max_conditions_lengths[condition_name]))] + data_list = [ + batch[idx][condition_name][arg] + for idx in range( + min( + len(batch), + self.max_conditions_lengths[condition_name], + ) + ) + ] single_cond_dict[arg] = self._collate(data_list) batch_dict[condition_name] = single_cond_dict @@ -114,8 +125,10 @@ class Collator: class PinaSampler: def __new__(cls, dataset, shuffle): - if (torch.distributed.is_available() and - torch.distributed.is_initialized()): + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): sampler = DistributedSampler(dataset, shuffle=shuffle) else: if shuffle: @@ -131,19 +144,20 @@ class PinaDataModule(LightningDataModule): management of different types of Datasets defined in PINA """ - def __init__(self, - problem, - train_size=.7, - test_size=.2, - val_size=.1, - predict_size=0., - batch_size=None, - shuffle=True, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ): + def __init__( + self, + problem, + train_size=0.7, + test_size=0.2, + val_size=0.1, + predict_size=0.0, + batch_size=None, + shuffle=True, + repeat=False, + automatic_batching=None, + num_workers=0, + pin_memory=False, + ): """ Initialize the object, creating datasets based on the input problem. @@ -170,8 +184,8 @@ class PinaDataModule(LightningDataModule): :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False) :type pin_memory: bool """ - logging.debug('Start initialization of Pina DataModule') - logging.info('Start initialization of Pina DataModule') + logging.debug("Start initialization of Pina DataModule") + logging.info("Start initialization of Pina DataModule") super().__init__() # Store fixed attributes @@ -182,13 +196,16 @@ class PinaDataModule(LightningDataModule): if batch_size is None and num_workers != 0: warnings.warn( "Setting num_workers when batch_size is None has no effect on " - "the DataLoading process.") + "the DataLoading process." + ) self.num_workers = 0 else: self.num_workers = num_workers if batch_size is None and pin_memory: - warnings.warn("Setting pin_memory to True has no effect when " - "batch_size is None.") + warnings.warn( + "Setting pin_memory to True has no effect when " + "batch_size is None." 
+ ) self.pin_memory = False else: self.pin_memory = pin_memory @@ -204,22 +221,22 @@ class PinaDataModule(LightningDataModule): # Split input data into subsets splits_dict = {} if train_size > 0: - splits_dict['train'] = train_size + splits_dict["train"] = train_size self.train_dataset = None else: self.train_dataloader = super().train_dataloader if test_size > 0: - splits_dict['test'] = test_size + splits_dict["test"] = test_size self.test_dataset = None else: self.test_dataloader = super().test_dataloader if val_size > 0: - splits_dict['val'] = val_size + splits_dict["val"] = val_size self.val_dataset = None else: self.val_dataloader = super().val_dataloader if predict_size > 0: - splits_dict['predict'] = predict_size + splits_dict["predict"] = predict_size self.predict_dataset = None else: self.predict_dataloader = super().predict_dataloader @@ -230,29 +247,36 @@ class PinaDataModule(LightningDataModule): """ Perform the splitting of the dataset """ - logging.debug('Start setup of Pina DataModule obj') - if stage == 'fit' or stage is None: + logging.debug("Start setup of Pina DataModule obj") + if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( - self.collector_splits['train'], + self.collector_splits["train"], max_conditions_lengths=self.find_max_conditions_lengths( - 'train'), automatic_batching=self.automatic_batching) - if 'val' in self.collector_splits.keys(): - self.val_dataset = PinaDatasetFactory( - self.collector_splits['val'], - max_conditions_lengths=self.find_max_conditions_lengths( - 'val'), automatic_batching=self.automatic_batching - ) - elif stage == 'test': - self.test_dataset = PinaDatasetFactory( - self.collector_splits['test'], - max_conditions_lengths=self.find_max_conditions_lengths( - 'test'), automatic_batching=self.automatic_batching + "train" + ), + automatic_batching=self.automatic_batching, ) - elif stage == 'predict': + if "val" in self.collector_splits.keys(): + self.val_dataset = PinaDatasetFactory( + self.collector_splits["val"], + max_conditions_lengths=self.find_max_conditions_lengths( + "val" + ), + automatic_batching=self.automatic_batching, + ) + elif stage == "test": + self.test_dataset = PinaDatasetFactory( + self.collector_splits["test"], + max_conditions_lengths=self.find_max_conditions_lengths("test"), + automatic_batching=self.automatic_batching, + ) + elif stage == "predict": self.predict_dataset = PinaDatasetFactory( - self.collector_splits['predict'], + self.collector_splits["predict"], max_conditions_lengths=self.find_max_conditions_lengths( - 'predict'), automatic_batching=self.automatic_batching + "predict" + ), + automatic_batching=self.automatic_batching, ) else: raise ValueError( @@ -261,28 +285,29 @@ class PinaDataModule(LightningDataModule): @staticmethod def _split_condition(condition_dict, splits_dict): - len_condition = len(condition_dict['input_points']) + len_condition = len(condition_dict["input_points"]) lengths = [ - int(len_condition * length) for length in - splits_dict.values() + int(len_condition * length) for length in splits_dict.values() ] remainder = len_condition - sum(lengths) for i in range(remainder): lengths[i % len(lengths)] += 1 - splits_dict = {k: max(1, v) for k, v in zip(splits_dict.keys(), lengths) - } + splits_dict = { + k: max(1, v) for k, v in zip(splits_dict.keys(), lengths) + } to_return_dict = {} offset = 0 for stage, stage_len in splits_dict.items(): - to_return_dict[stage] = {k: v[offset:offset + stage_len] - for k, v in condition_dict.items() if - k != 'equation' - # 
Equations are NEVER dataloaded - } + to_return_dict[stage] = { + k: v[offset : offset + stage_len] + for k, v in condition_dict.items() + if k != "equation" + # Equations are NEVER dataloaded + } if offset + stage_len > len_condition: offset = len_condition - 1 continue @@ -298,13 +323,12 @@ class PinaDataModule(LightningDataModule): def _apply_shuffle(condition_dict, len_data): idx = torch.randperm(len_data) for k, v in condition_dict.items(): - if k == 'equation': + if k == "equation": continue if isinstance(v, list): condition_dict[k] = [v[i] for i in idx] elif isinstance(v, LabelTensor): - condition_dict[k] = LabelTensor(v.tensor[idx], - v.labels) + condition_dict[k] = LabelTensor(v.tensor[idx], v.labels) elif isinstance(v, torch.Tensor): condition_dict[k] = v[idx] else: @@ -312,42 +336,53 @@ class PinaDataModule(LightningDataModule): # ----------- End auxiliary function ------------ - logging.debug('Dataset creation in PinaDataModule obj') + logging.debug("Dataset creation in PinaDataModule obj") split_names = list(splits_dict.keys()) dataset_dict = {name: {} for name in split_names} - for condition_name, condition_dict in collector.data_collections.items(): - len_data = len(condition_dict['input_points']) + for ( + condition_name, + condition_dict, + ) in collector.data_collections.items(): + len_data = len(condition_dict["input_points"]) if self.shuffle: _apply_shuffle(condition_dict, len_data) - for key, data in self._split_condition(condition_dict, - splits_dict).items(): + for key, data in self._split_condition( + condition_dict, splits_dict + ).items(): dataset_dict[key].update({condition_name: data}) return dataset_dict def _create_dataloader(self, split, dataset): - shuffle = self.shuffle if split == 'train' else False + shuffle = self.shuffle if split == "train" else False # Suppress the warning about num_workers. # In many cases, especially for PINNs, serial data loading can outperform parallel data loading. warnings.filterwarnings( "ignore", message=( - r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck."), - module="lightning.pytorch.trainer.connectors.data_connector" + r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck." 
+ ), + module="lightning.pytorch.trainer.connectors.data_connector", ) # Use custom batching (good if batch size is large) if self.batch_size is not None: sampler = PinaSampler(dataset, shuffle) if self.automatic_batching: - collate = Collator(self.find_max_conditions_lengths(split), - dataset=dataset) + collate = Collator( + self.find_max_conditions_lengths(split), dataset=dataset + ) else: collate = Collator(None, dataset=dataset) - return DataLoader(dataset, self.batch_size, - collate_fn=collate, sampler=sampler, - num_workers=self.num_workers) + return DataLoader( + dataset, + self.batch_size, + collate_fn=collate, + sampler=sampler, + num_workers=self.num_workers, + ) dataloader = DummyDataloader(dataset) dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0) + dataloader.dataset, self.trainer.strategy.root_device, 0 + ) self.transfer_batch_to_device = self._transfer_batch_to_device_dummy return dataloader @@ -355,31 +390,32 @@ class PinaDataModule(LightningDataModule): max_conditions_lengths = {} for k, v in self.collector_splits[split].items(): if self.batch_size is None: - max_conditions_lengths[k] = len(v['input_points']) + max_conditions_lengths[k] = len(v["input_points"]) elif self.repeat: max_conditions_lengths[k] = self.batch_size else: - max_conditions_lengths[k] = min(len(v['input_points']), - self.batch_size) + max_conditions_lengths[k] = min( + len(v["input_points"]), self.batch_size + ) return max_conditions_lengths def val_dataloader(self): """ Create the validation dataloader """ - return self._create_dataloader('val', self.val_dataset) + return self._create_dataloader("val", self.val_dataset) def train_dataloader(self): """ Create the training dataloader """ - return self._create_dataloader('train', self.train_dataset) + return self._create_dataloader("train", self.train_dataset) def test_dataloader(self): """ Create the testing dataloader """ - return self._create_dataloader('test', self.test_dataset) + return self._create_dataloader("test", self.test_dataset) def predict_dataloader(self): """ @@ -397,9 +433,12 @@ class PinaDataModule(LightningDataModule): training loop and is used to transfer the batch to the device. 
""" batch = [ - (k, - super(LightningDataModule, self).transfer_batch_to_device( - v, device, dataloader_idx)) + ( + k, + super(LightningDataModule, self).transfer_batch_to_device( + v, device, dataloader_idx + ), + ) for k, v in batch.items() ] diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 2fecb93..3944ef4 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -1,6 +1,7 @@ """ This module provide basic data management functionalities """ + import functools import torch from torch.utils.data import Dataset @@ -19,15 +20,24 @@ class PinaDatasetFactory: def __new__(cls, conditions_dict, **kwargs): if len(conditions_dict) == 0: - raise ValueError('No conditions provided') - if all([isinstance(v['input_points'], torch.Tensor) for v - in conditions_dict.values()]): + raise ValueError("No conditions provided") + if all( + [ + isinstance(v["input_points"], torch.Tensor) + for v in conditions_dict.values() + ] + ): return PinaTensorDataset(conditions_dict, **kwargs) - elif all([isinstance(v['input_points'], list) for v - in conditions_dict.values()]): + elif all( + [ + isinstance(v["input_points"], list) + for v in conditions_dict.values() + ] + ): return PinaGraphDataset(conditions_dict, **kwargs) - raise ValueError('Conditions must be either torch.Tensor or list of Data ' - 'objects.') + raise ValueError( + "Conditions must be either torch.Tensor or list of Data " "objects." + ) class PinaDataset(Dataset): @@ -38,14 +48,15 @@ class PinaDataset(Dataset): def __init__(self, conditions_dict, max_conditions_lengths): self.conditions_dict = conditions_dict self.max_conditions_lengths = max_conditions_lengths - self.conditions_length = {k: len(v['input_points']) for k, v in - self.conditions_dict.items()} + self.conditions_length = { + k: len(v["input_points"]) for k, v in self.conditions_dict.items() + } self.length = max(self.conditions_length.values()) def _get_max_len(self): max_len = 0 for condition in self.conditions_dict.values(): - max_len = max(max_len, len(condition['input_points'])) + max_len = max(max_len, len(condition["input_points"])) return max_len def __len__(self): @@ -57,8 +68,9 @@ class PinaDataset(Dataset): class PinaTensorDataset(PinaDataset): - def __init__(self, conditions_dict, max_conditions_lengths, - automatic_batching): + def __init__( + self, conditions_dict, max_conditions_lengths, automatic_batching + ): super().__init__(conditions_dict, max_conditions_lengths) if automatic_batching: @@ -68,19 +80,23 @@ class PinaTensorDataset(PinaDataset): def _getitem_int(self, idx): return { - k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data - in v.keys()} for k, v in self.conditions_dict.items() + k: { + k_data: v[k_data][idx % len(v["input_points"])] + for k_data in v.keys() + } + for k, v in self.conditions_dict.items() } def fetch_from_idx_list(self, idx): to_return_dict = {} for condition, data in self.conditions_dict.items(): - cond_idx = idx[:self.max_conditions_lengths[condition]] + cond_idx = idx[: self.max_conditions_lengths[condition]] condition_len = self.conditions_length[condition] if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] - to_return_dict[condition] = {k: v[cond_idx] - for k, v in data.items()} + to_return_dict[condition] = { + k: v[cond_idx] for k, v in data.items() + } return to_return_dict @staticmethod @@ -99,15 +115,14 @@ class PinaTensorDataset(PinaDataset): """ Method to return input points for training. 
""" - return { - k: v['input_points'] for k, v in self.conditions_dict.items() - } + return {k: v["input_points"] for k, v in self.conditions_dict.items()} class PinaBatch(Batch): """ Add extract function to torch_geometric Batch object """ + def __init__(self): super().__init__(self) @@ -116,8 +131,8 @@ class PinaBatch(Batch): """ Perform extraction of labels on node features (x) - :param labels: Labels to extract - :type labels: list[str] | tuple[str] | str + :param labels: Labels to extract + :type labels: list[str] | tuple[str] | str :return: Batch object with extraction performed on x :rtype: PinaBatch """ @@ -127,8 +142,9 @@ class PinaBatch(Batch): class PinaGraphDataset(PinaDataset): - def __init__(self, conditions_dict, max_conditions_lengths, - automatic_batching): + def __init__( + self, conditions_dict, max_conditions_lengths, automatic_batching + ): super().__init__(conditions_dict, max_conditions_lengths) self.in_labels = {} self.out_labels = None @@ -137,35 +153,43 @@ class PinaGraphDataset(PinaDataset): else: self._getitem_func = self._getitem_dummy - ex_data = conditions_dict[list(conditions_dict.keys())[ - 0]]['input_points'][0] + ex_data = conditions_dict[list(conditions_dict.keys())[0]][ + "input_points" + ][0] for name, attr in ex_data.items(): if isinstance(attr, LabelTensor): self.in_labels[name] = attr.stored_labels - ex_data = conditions_dict[list(conditions_dict.keys())[ - 0]]['output_points'][0] + ex_data = conditions_dict[list(conditions_dict.keys())[0]][ + "output_points" + ][0] if isinstance(ex_data, LabelTensor): self.out_labels = ex_data.labels - self._create_graph_batch_from_list = self._labelise_batch( - self._base_create_graph_batch_from_list) if self.in_labels \ + self._create_graph_batch_from_list = ( + self._labelise_batch(self._base_create_graph_batch_from_list) + if self.in_labels else self._base_create_graph_batch_from_list + ) - self._create_output_batch = self._labelise_tensor( - self._base_create_output_batch) if self.out_labels is not None \ + self._create_output_batch = ( + self._labelise_tensor(self._base_create_output_batch) + if self.out_labels is not None else self._base_create_output_batch + ) def fetch_from_idx_list(self, idx): to_return_dict = {} for condition, data in self.conditions_dict.items(): - cond_idx = idx[:self.max_conditions_lengths[condition]] + cond_idx = idx[: self.max_conditions_lengths[condition]] condition_len = self.conditions_length[condition] if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] to_return_dict[condition] = { - k: self._create_graph_batch_from_list([v[i] for i in idx]) - if isinstance(v, list) - else self._create_output_batch(v[idx]) + k: ( + self._create_graph_batch_from_list([v[i] for i in idx]) + if isinstance(v, list) + else self._create_output_batch(v[idx]) + ) for k, v in data.items() } @@ -184,8 +208,11 @@ class PinaGraphDataset(PinaDataset): def _getitem_int(self, idx): return { - k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data - in v.keys()} for k, v in self.conditions_dict.items() + k: { + k_data: v[k_data][idx % len(v["input_points"])] + for k_data in v.keys() + } + for k, v in self.conditions_dict.items() } def get_all_data(self): @@ -204,6 +231,7 @@ class PinaGraphDataset(PinaDataset): tmp.labels = v batch[k] = tmp return batch + return wrapper def _labelise_tensor(self, func): @@ -213,6 +241,7 @@ class PinaGraphDataset(PinaDataset): if isinstance(out, LabelTensor): out.labels = self.out_labels return out + return wrapper def 
create_graph_batch(self, data): diff --git a/pina/domain/cartesian.py b/pina/domain/cartesian.py index 48e5e4d..9c312ea 100644 --- a/pina/domain/cartesian.py +++ b/pina/domain/cartesian.py @@ -168,8 +168,9 @@ class CartesianDomain(DomainInterface): for variable in variables: if variable in self.fixed_.keys(): value = self.fixed_[variable] - pts_variable = torch.tensor([[value] - ]).repeat(result.shape[0], 1) + pts_variable = torch.tensor([[value]]).repeat( + result.shape[0], 1 + ) pts_variable = pts_variable.as_subclass(LabelTensor) pts_variable.labels = [variable] @@ -202,8 +203,9 @@ class CartesianDomain(DomainInterface): for variable in variables: if variable in self.fixed_.keys(): value = self.fixed_[variable] - pts_variable = torch.tensor([[value] - ]).repeat(result.shape[0], 1) + pts_variable = torch.tensor([[value]]).repeat( + result.shape[0], 1 + ) pts_variable = pts_variable.as_subclass(LabelTensor) pts_variable.labels = [variable] diff --git a/pina/domain/domain_interface.py b/pina/domain/domain_interface.py index 916bf3e..265b64f 100644 --- a/pina/domain/domain_interface.py +++ b/pina/domain/domain_interface.py @@ -36,9 +36,11 @@ class DomainInterface(metaclass=ABCMeta): values = [values] for value in values: if value not in DomainInterface.available_sampling_modes: - raise TypeError(f"mode {value} not valid. Expected at least " - "one in " - f"{DomainInterface.available_sampling_modes}.") + raise TypeError( + f"mode {value} not valid. Expected at least " + "one in " + f"{DomainInterface.available_sampling_modes}." + ) @abstractmethod def sample(self): diff --git a/pina/domain/exclusion_domain.py b/pina/domain/exclusion_domain.py index a05b154..6b04b0c 100644 --- a/pina/domain/exclusion_domain.py +++ b/pina/domain/exclusion_domain.py @@ -1,4 +1,4 @@ -"""Module for Exclusion class. """ +"""Module for Exclusion class.""" import torch from ..label_tensor import LabelTensor diff --git a/pina/domain/intersection_domain.py b/pina/domain/intersection_domain.py index bb0499b..906595f 100644 --- a/pina/domain/intersection_domain.py +++ b/pina/domain/intersection_domain.py @@ -1,4 +1,4 @@ -"""Module for Intersection class. """ +"""Module for Intersection class.""" import torch from ..label_tensor import LabelTensor diff --git a/pina/domain/operation_interface.py b/pina/domain/operation_interface.py index e42d37e..7023eb9 100644 --- a/pina/domain/operation_interface.py +++ b/pina/domain/operation_interface.py @@ -1,4 +1,4 @@ -""" Module for OperationInterface class. """ +"""Module for OperationInterface class.""" from .domain_interface import DomainInterface from ..utils import check_consistency diff --git a/pina/domain/simplex.py b/pina/domain/simplex.py index 6915c12..1e706c6 100644 --- a/pina/domain/simplex.py +++ b/pina/domain/simplex.py @@ -144,7 +144,7 @@ class SimplexDomain(DomainInterface): return all(torch.gt(lambdas, 0.0)) and all(torch.lt(lambdas, 1.0)) return all(torch.ge(lambdas, 0)) and ( - any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1)) + any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1)) ) def _sample_interior_randomly(self, n, variables): diff --git a/pina/domain/union_domain.py b/pina/domain/union_domain.py index 91aa5fb..813cc74 100644 --- a/pina/domain/union_domain.py +++ b/pina/domain/union_domain.py @@ -1,4 +1,4 @@ -"""Module for Union class. 
""" +"""Module for Union class.""" import torch from .operation_interface import OperationInterface diff --git a/pina/equation/equation.py b/pina/equation/equation.py index 3a8f4b1..6ab28cb 100644 --- a/pina/equation/equation.py +++ b/pina/equation/equation.py @@ -1,4 +1,4 @@ -""" Module for Equation. """ +"""Module for Equation.""" from .equation_interface import EquationInterface diff --git a/pina/equation/equation_factory.py b/pina/equation/equation_factory.py index cdc4c3f..6894659 100644 --- a/pina/equation/equation_factory.py +++ b/pina/equation/equation_factory.py @@ -1,4 +1,4 @@ -""" Module """ +"""Module""" from .equation import Equation from ..operator import grad, div, laplacian diff --git a/pina/equation/equation_interface.py b/pina/equation/equation_interface.py index c64c180..982b431 100644 --- a/pina/equation/equation_interface.py +++ b/pina/equation/equation_interface.py @@ -1,4 +1,4 @@ -""" Module for EquationInterface class """ +"""Module for EquationInterface class""" from abc import ABCMeta, abstractmethod diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index bf54abd..2ed54ae 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -1,4 +1,4 @@ -""" Module for SystemEquation. """ +"""Module for SystemEquation.""" import torch from .equation import Equation diff --git a/pina/geometry/__init__.py b/pina/geometry/__init__.py index 47cc4a4..e627b29 100644 --- a/pina/geometry/__init__.py +++ b/pina/geometry/__init__.py @@ -11,7 +11,8 @@ Location = DomainInterface warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - "'pina.geometry' is deprecated and will be removed " - "in future versions. Please use 'pina.domain' instead. " - "Location moved to DomainInferface object.", - DeprecationWarning) \ No newline at end of file + "'pina.geometry' is deprecated and will be removed " + "in future versions. Please use 'pina.domain' instead. " + "Location moved to DomainInferface object.", + DeprecationWarning, +) diff --git a/pina/graph.py b/pina/graph.py index ca92ab4..3bfb370 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -14,15 +14,15 @@ class Graph: """ def __init__( - self, - x, - pos, - edge_index, - edge_attr=None, - build_edge_attr=False, - undirected=False, - custom_build_edge_attr=None, - additional_params=None + self, + x, + pos, + edge_index, + edge_attr=None, + build_edge_attr=False, + undirected=False, + custom_build_edge_attr=None, + additional_params=None, ): """ Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects. 
@@ -72,8 +72,9 @@ class Graph: self._build_edge_attr = custom_build_edge_attr # Check consistency and initialize additional_parameters (if present) - additional_params = self._check_additional_params(additional_params, - data_len) + additional_params = self._check_additional_params( + additional_params, data_len + ) # Make the graphs undirected if undirected: @@ -84,49 +85,63 @@ class Graph: # Prepare internal lists to create a graph list (same positions but # different node features) - if isinstance(x, list) and isinstance(pos, - (torch.Tensor, LabelTensor)): + if isinstance(x, list) and isinstance(pos, (torch.Tensor, LabelTensor)): # Replicate the positions, edge_index and edge_attr pos, edge_index = [pos] * data_len, [edge_index] * data_len # Prepare internal lists to create a list containing a single graph - elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, ( - torch.Tensor, LabelTensor)): + elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance( + pos, (torch.Tensor, LabelTensor) + ): # Encapsulate the input tensors into lists x, pos, edge_index = [x], [pos], [edge_index] # Prepare internal lists to create a list of graphs (same node features # but different positions) - elif (isinstance(x, (torch.Tensor, LabelTensor)) - and isinstance(pos, list)): + elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance( + pos, list + ): # Replicate the node features x = [x] * data_len elif not isinstance(x, list) and not isinstance(pos, list): raise TypeError("x and pos must be lists or tensors.") # Build the edge attributes - edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr, - data_len, edge_index, pos, - x) + edge_attr = self._check_and_build_edge_attr( + edge_attr, build_edge_attr, data_len, edge_index, pos, x + ) # Perform the graph construction - self._build_graph_list( - x, pos, edge_index, edge_attr, additional_params) + self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) - def _build_graph_list(self, x, pos, edge_index, edge_attr, - additional_params): + def _build_graph_list( + self, x, pos, edge_index, edge_attr, additional_params + ): for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)): add_params_local = {k: v[i] for k, v in additional_params.items()} if edge_attr is not None: - self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_, - edge_attr=edge_attr[i], - **add_params_local)) + self.data.append( + Data( + x=x_, + pos=pos_, + edge_index=edge_index_, + edge_attr=edge_attr[i], + **add_params_local, + ) + ) else: - self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_, - **add_params_local)) + self.data.append( + Data( + x=x_, + pos=pos_, + edge_index=edge_index_, + **add_params_local, + ) + ) @staticmethod def _build_edge_attr(x, pos, edge_index): - distance = torch.abs(pos[edge_index[0]] - - pos[edge_index[1]]).as_subclass(torch.Tensor) + distance = torch.abs( + pos[edge_index[0]] - pos[edge_index[1]] + ).as_subclass(torch.Tensor) return distance @staticmethod @@ -147,32 +162,39 @@ class Graph: # If x is a 3D tensor, we split it into a list of 2D tensors if isinstance(x, torch.Tensor) and x.ndim == 3: x = [x[i] for i in range(x.shape[0])] - elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and - not (isinstance(x, torch.Tensor) and x.ndim == 2)): - raise TypeError("x must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor") + elif not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and not ( + isinstance(x, torch.Tensor) and x.ndim == 2 + ): + 
raise TypeError( + "x must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor" + ) # If pos is a 3D tensor, we split it into a list of 2D tensors if isinstance(pos, torch.Tensor) and pos.ndim == 3: pos = [pos[i] for i in range(pos.shape[0])] - elif not (isinstance(pos, list) and all( - t.ndim == 2 for t in pos)) and not ( - isinstance(pos, torch.Tensor) and pos.ndim == 2): - raise TypeError("pos must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor") + elif not ( + isinstance(pos, list) and all(t.ndim == 2 for t in pos) + ) and not (isinstance(pos, torch.Tensor) and pos.ndim == 2): + raise TypeError( + "pos must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor" + ) # If edge_index is a 3D tensor, we split it into a list of 2D tensors if edge_index is not None: if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: - edge_index = [edge_index[i] - for i in range(edge_index.shape[0])] - elif not (isinstance(edge_index, list) and all( - t.ndim == 2 for t in edge_index)) and not ( - isinstance(edge_index, - torch.Tensor) and edge_index.ndim == 2): + edge_index = [edge_index[i] for i in range(edge_index.shape[0])] + elif not ( + isinstance(edge_index, list) + and all(t.ndim == 2 for t in edge_index) + ) and not ( + isinstance(edge_index, torch.Tensor) and edge_index.ndim == 2 + ): raise TypeError( "edge_index must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor") + "tensor or a 3D tensor" + ) return x, pos, edge_index @@ -188,8 +210,9 @@ class Graph: # In this case there must be a additional parameter for each # node if val.ndim == 3: - additional_params[param] = [val[i] for i in - range(val.shape[0])] + additional_params[param] = [ + val[i] for i in range(val.shape[0]) + ] # If the tensor is 2D, we replicate it for each node elif val.ndim == 2: additional_params[param] = [val] * data_len @@ -197,44 +220,48 @@ class Graph: # additional parameter if val.ndim == 1: if len(val) == data_len: - additional_params[param] = [val[i] for i in - range(len(val))] + additional_params[param] = [ + val[i] for i in range(len(val)) + ] else: - additional_params[param] = [val for _ in - range(data_len)] + additional_params[param] = [ + val for _ in range(data_len) + ] elif not isinstance(val, list): - raise TypeError("additional_params values must be tensors " - "or lists of tensors.") + raise TypeError( + "additional_params values must be tensors " + "or lists of tensors." + ) else: additional_params = {} return additional_params - def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len, - edge_index, pos, x): + def _check_and_build_edge_attr( + self, edge_attr, build_edge_attr, data_len, edge_index, pos, x + ): # Check if edge_attr is consistent with x and pos if edge_attr is not None: if build_edge_attr is True: - warning("edge_attr is not None. build_edge_attr will not be " - "considered.") + warning( + "edge_attr is not None. build_edge_attr will not be " + "considered." + ) if isinstance(edge_attr, list): if len(edge_attr) != data_len: - raise TypeError("edge_attr must have the same length as x " - "and pos.") + raise TypeError( + "edge_attr must have the same length as x " "and pos." 
+ ) return [edge_attr] * data_len if build_edge_attr: - return [self._build_edge_attr(x_, pos_, edge_index_) for - x_, pos_, edge_index_ in zip(x, pos, edge_index)] + return [ + self._build_edge_attr(x_, pos_, edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] class RadiusGraph(Graph): - def __init__( - self, - x, - pos, - r, - **kwargs - ): + def __init__(self, x, pos, r, **kwargs): x, pos, edge_index = Graph._check_input_consistency(x, pos) if isinstance(pos, (torch.Tensor, LabelTensor)): @@ -242,8 +269,7 @@ class RadiusGraph(Graph): else: edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] - super().__init__(x=x, pos=pos, edge_index=edge_index, - **kwargs) + super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def _radius_graph(points, r): @@ -264,20 +290,13 @@ class RadiusGraph(Graph): class KNNGraph(Graph): - def __init__( - self, - x, - pos, - k, - **kwargs - ): + def __init__(self, x, pos, k, **kwargs): x, pos, edge_index = Graph._check_input_consistency(x, pos) if isinstance(pos, (torch.Tensor, LabelTensor)): edge_index = KNNGraph._knn_graph(pos, k) else: edge_index = [KNNGraph._knn_graph(p, k) for p in pos] - super().__init__(x=x, pos=pos, edge_index=edge_index, - **kwargs) + super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def _knn_graph(points, k): diff --git a/pina/label_tensor.py b/pina/label_tensor.py index d8f66f9..4c30f4b 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -1,4 +1,5 @@ -""" Module for LabelTensor """ +"""Module for LabelTensor""" + from copy import copy, deepcopy import torch from torch import Tensor @@ -43,7 +44,7 @@ class LabelTensor(torch.Tensor): :rtype: list """ if self.ndim - 1 in self._labels.keys(): - return self._labels[self.ndim - 1]['dof'] + return self._labels[self.ndim - 1]["dof"] @property def full_labels(self): @@ -58,7 +59,7 @@ class LabelTensor(torch.Tensor): if i in self._labels.keys(): to_return_dict[i] = self._labels[i] else: - to_return_dict[i] = {'dof': range(shape_tensor[i]), 'name': i} + to_return_dict[i] = {"dof": range(shape_tensor[i]), "name": i} return to_return_dict @property @@ -72,13 +73,13 @@ class LabelTensor(torch.Tensor): @labels.setter def labels(self, labels): - """" + """ " Set properly the parameter _labels :param labels: Labels to assign to the class variable _labels. 
:type: labels: str | list(str) | dict """ - if not hasattr(self, '_labels'): + if not hasattr(self, "_labels"): self._labels = {} if isinstance(labels, dict): self._init_labels_from_dict(labels) @@ -109,27 +110,34 @@ class LabelTensor(torch.Tensor): if len(dof_list) != dim_size: raise ValueError( f"Number of dof ({len(dof_list)}) does not match " - f"tensor shape ({dim_size})") + f"tensor shape ({dim_size})" + ) for dim, label in labels.items(): if isinstance(label, dict): - if 'name' not in label: - label['name'] = dim - if 'dof' not in label: - label['dof'] = range(tensor_shape[dim]) - if 'dof' in label and 'name' in label: - dof = label['dof'] + if "name" not in label: + label["name"] = dim + if "dof" not in label: + label["dof"] = range(tensor_shape[dim]) + if "dof" in label and "name" in label: + dof = label["dof"] dof_list = dof if isinstance(dof, (list, range)) else [dof] if not isinstance(dof_list, (list, range)): - raise ValueError(f"'dof' should be a list or range, not" - f" {type(dof_list)}") + raise ValueError( + f"'dof' should be a list or range, not" + f" {type(dof_list)}" + ) validate_dof(dof_list, tensor_shape[dim]) else: - raise ValueError("Labels dictionary must contain either " - " both 'name' and 'dof' keys") + raise ValueError( + "Labels dictionary must contain either " + " both 'name' and 'dof' keys" + ) else: - raise ValueError(f"Invalid label format for {dim}: Expected " - f"list or dictionary, got {type(label)}") + raise ValueError( + f"Invalid label format for {dim}: Expected " + f"list or dictionary, got {type(label)}" + ) # Assign validated label data to internal labels self._labels[dim] = label @@ -144,10 +152,7 @@ class LabelTensor(torch.Tensor): """ # Create a dict with labels last_dim_labels = { - self.ndim - 1: { - 'dof': labels, - 'name': self.ndim - 1 - } + self.ndim - 1: {"dof": labels, "name": self.ndim - 1} } self._init_labels_from_dict(last_dim_labels) @@ -165,9 +170,14 @@ class LabelTensor(torch.Tensor): def get_label_indices(dim_labels, labels_te): if isinstance(labels_te, (int, str)): labels_te = [labels_te] - return [dim_labels.index(label) for label in labels_te] if len( - labels_te) > 1 else slice(dim_labels.index(labels_te[0]), - dim_labels.index(labels_te[0]) + 1) + return ( + [dim_labels.index(label) for label in labels_te] + if len(labels_te) > 1 + else slice( + dim_labels.index(labels_te[0]), + dim_labels.index(labels_te[0]) + 1, + ) + ) # Ensure labels_to_extract is a list or dict if isinstance(labels_to_extract, (str, int)): @@ -176,37 +186,39 @@ class LabelTensor(torch.Tensor): labels = copy(self._labels) # Get the dimension names and the respective dimension index - dim_names = {labels[dim]['name']: dim for dim in labels.keys()} + dim_names = {labels[dim]["name"]: dim for dim in labels.keys()} ndim = super().ndim tensor = self.tensor.as_subclass(torch.Tensor) # Convert list/tuple to a dict for the last dimension if applicable if isinstance(labels_to_extract, (list, tuple)): last_dim = ndim - 1 - dim_name = labels[last_dim]['name'] + dim_name = labels[last_dim]["name"] labels_to_extract = {dim_name: list(labels_to_extract)} # Validate the labels_to_extract type if not isinstance(labels_to_extract, dict): raise ValueError( - "labels_to_extract must be a string, list, or dictionary.") + "labels_to_extract must be a string, list, or dictionary." 
+ ) # Perform the extraction for each specified dimension for dim_name, labels_te in labels_to_extract.items(): if dim_name not in dim_names: raise ValueError( f"Cannot extract labels for dimension '{dim_name}' as it is" - f" not present in the original labels.") + f" not present in the original labels." + ) idx_dim = dim_names[dim_name] - dim_labels = labels[idx_dim]['dof'] + dim_labels = labels[idx_dim]["dof"] indices = get_label_indices(dim_labels, labels_te) extractor = [slice(None)] * ndim extractor[idx_dim] = indices tensor = tensor[tuple(extractor)] - labels[idx_dim] = {'dof': labels_te, 'name': dim_name} + labels[idx_dim] = {"dof": labels_te, "name": dim_name} return LabelTensor(tensor, labels) @@ -214,10 +226,10 @@ class LabelTensor(torch.Tensor): """ returns a string with the representation of the class """ - s = '' + s = "" for key, value in self._labels.items(): s += f"{key}: {value}\n" - s += '\n' + s += "\n" s += self.tensor.__str__() return s @@ -249,11 +261,14 @@ class LabelTensor(torch.Tensor): # concatenation dimension for key in tensors_labels[0].keys(): if key != dim: - if any(tensors_labels[i][key] != tensors_labels[0][key] - for i in range(len(tensors_labels))): + if any( + tensors_labels[i][key] != tensors_labels[0][key] + for i in range(len(tensors_labels)) + ): raise RuntimeError( f"Tensors must have the same labels along all " - f"dimensions except {dim}.") + f"dimensions except {dim}." + ) # Copy and update the 'dof' for the concatenation dimension cat_labels = {k: copy(v) for k, v in tensors_labels[0].items()} @@ -261,9 +276,8 @@ class LabelTensor(torch.Tensor): # Update labels if the concatenation dimension has labels if dim in tensors[0].stored_labels: if dim in cat_labels: - cat_dofs = [label[dim]['dof'] for label in - tensors_labels] - cat_labels[dim]['dof'] = sum(cat_dofs, []) + cat_dofs = [label[dim]["dof"] for label in tensors_labels] + cat_labels[dim]["dof"] = sum(cat_dofs, []) else: cat_labels = tensors[0].stored_labels @@ -330,26 +344,30 @@ class LabelTensor(torch.Tensor): :return: A copy of the tensor. 
:rtype: LabelTensor """ - out = LabelTensor(super().clone(*args, **kwargs), - deepcopy(self._labels)) + out = LabelTensor( + super().clone(*args, **kwargs), deepcopy(self._labels) + ) return out - def append(self, tensor, mode='std'): - if mode == 'std': + def append(self, tensor, mode="std"): + if mode == "std": # Call cat on last dimension - new_label_tensor = LabelTensor.cat([self, tensor], - dim=self.ndim - 1) - elif mode == 'cross': + new_label_tensor = LabelTensor.cat( + [self, tensor], dim=self.ndim - 1 + ) + elif mode == "cross": # Crete tensor and call cat on last dimension tensor1 = self tensor2 = tensor n1 = tensor1.shape[0] n2 = tensor2.shape[0] tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) - tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), - labels=tensor2.labels) - new_label_tensor = LabelTensor.cat([tensor1, tensor2], - dim=self.ndim - 1) + tensor2 = LabelTensor( + tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels + ) + new_label_tensor = LabelTensor.cat( + [tensor1, tensor2], dim=self.ndim - 1 + ) else: raise ValueError('mode must be either "std" or "cross"') return new_label_tensor @@ -368,8 +386,9 @@ class LabelTensor(torch.Tensor): return LabelTensor.cat(label_tensors, dim=0) # This method is used to update labels - def _update_single_label(self, old_labels, to_update_labels, index, dim, - to_update_dim): + def _update_single_label( + self, old_labels, to_update_labels, index, dim, to_update_dim + ): """ Update the labels of the tensor by selecting only the labels :param old_labels: labels from which retrieve data @@ -378,24 +397,29 @@ class LabelTensor(torch.Tensor): :param dim: label index :return: """ - old_dof = old_labels[to_update_dim]['dof'] - label_name = old_labels[dim]['name'] + old_dof = old_labels[to_update_dim]["dof"] + label_name = old_labels[dim]["name"] # Handle slicing if isinstance(index, slice): - to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name} + to_update_labels[dim] = {"dof": old_dof[index], "name": label_name} # Handle single integer index elif isinstance(index, int): - to_update_labels[dim] = {'dof': [old_dof[index]], - 'name': label_name} + to_update_labels[dim] = { + "dof": [old_dof[index]], + "name": label_name, + } # Handle lists or tensors elif isinstance(index, (list, torch.Tensor)): # Handle list of bools if isinstance(index, torch.Tensor) and index.dtype == torch.bool: index = index.nonzero().squeeze() to_update_labels[dim] = { - 'dof': [old_dof[i] for i in index] if isinstance(old_dof, - list) else index, - 'name': label_name + "dof": ( + [old_dof[i] for i in index] + if isinstance(old_dof, list) + else index + ), + "name": label_name, } else: raise NotImplementedError( @@ -404,7 +428,7 @@ class LabelTensor(torch.Tensor): ) def __getitem__(self, index): - """" + """ " Override the __getitem__ method to handle the labels of the tensor. Perform the __getitem__ operation on the tensor and update the labels. @@ -416,8 +440,10 @@ class LabelTensor(torch.Tensor): :raises IndexError: If an invalid index is accessed in the tensor. 
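A minimal sketch of the two append modes handled above, using the LabelTensor API as shown in this patch (shapes are hypothetical):

import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(3, 1), ["x"])
b = LabelTensor(torch.rand(3, 1), ["y"])
ab = a.append(b)                # mode="std": concatenate on the last dim -> (3, 2), labels ["x", "y"]

t = LabelTensor(torch.rand(2, 1), ["t"])
xt = a.append(t, mode="cross")  # every (x, t) pair -> (3 * 2, 2), labels ["x", "t"]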
""" # Handle string index - if isinstance(index, str) or (isinstance(index, (tuple, list)) and all( - isinstance(i, str) for i in index)): + if isinstance(index, str) or ( + isinstance(index, (tuple, list)) + and all(isinstance(i, str) for i in index) + ): return self.extract(index) # Retrieve selected tensor and labels @@ -436,8 +462,9 @@ class LabelTensor(torch.Tensor): if isinstance(idx, int): selected_tensor = selected_tensor.unsqueeze(dim) if idx != slice(None): - self._update_single_label(original_labels, updated_labels, - idx, dim, offset) + self._update_single_label( + original_labels, updated_labels, idx, dim, offset + ) else: # Adjust label keys if dimension is reduced (case of integer # index on a non-labeled dimension) @@ -472,7 +499,7 @@ class LabelTensor(torch.Tensor): dim = self.ndim - 1 if self.shape[dim] == 1: return self - labels = self.stored_labels[dim]['dof'] + labels = self.stored_labels[dim]["dof"] sorted_index = arg_sort(labels) # Define an indexer to sort the tensor along the specified dimension indexer = [slice(None)] * self.ndim @@ -509,10 +536,7 @@ class LabelTensor(torch.Tensor): # Update lables labels = self._labels keys_list = list(*dims) - labels = { - keys_list.index(k): labels[k] - for k in labels.keys() - } + labels = {keys_list.index(k): labels[k] for k in labels.keys()} # Assign labels to the new tensor tensor._labels = labels @@ -550,7 +574,7 @@ class LabelTensor(torch.Tensor): """ if not tensors: - raise ValueError('The tensors list must not be empty.') + raise ValueError("The tensors list must not be empty.") if len(tensors) == 1: return tensors[0] @@ -565,13 +589,13 @@ class LabelTensor(torch.Tensor): last_dim_labels.append(tensor.labels) # Construct last dimension labels - last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)] + last_dim_labels = ["+".join(items) for items in zip(*last_dim_labels)] # Update the labels for the resulting tensor labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()} labels[tensors[0].ndim - 1] = { - 'dof': last_dim_labels, - 'name': tensors[0].name + "dof": last_dim_labels, + "name": tensors[0].name, } return LabelTensor(data, labels) diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index a617e64..a4d7f69 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -1,9 +1,9 @@ __all__ = [ - 'LossInterface', - 'LpLoss', - 'PowerLoss', - 'WeightingInterface', - 'ScalarWeighting' + "LossInterface", + "LpLoss", + "PowerLoss", + "WeightingInterface", + "ScalarWeighting", ] from .loss_interface import LossInterface diff --git a/pina/loss/loss_interface.py b/pina/loss/loss_interface.py index 5093a65..b6b4dc1 100644 --- a/pina/loss/loss_interface.py +++ b/pina/loss/loss_interface.py @@ -1,4 +1,4 @@ -""" Module for Loss Interface """ +"""Module for Loss Interface""" from abc import ABCMeta, abstractmethod from torch.nn.modules.loss import _Loss @@ -58,4 +58,4 @@ class LossInterface(_Loss, metaclass=ABCMeta): ret = torch.sum(loss, keepdim=True, dim=-1) else: raise ValueError(self.reduction + " is not valid") - return ret \ No newline at end of file + return ret diff --git a/pina/loss/lp_loss.py b/pina/loss/lp_loss.py index 978efa8..b39b16e 100644 --- a/pina/loss/lp_loss.py +++ b/pina/loss/lp_loss.py @@ -1,10 +1,11 @@ -""" Module for LpLoss class """ +"""Module for LpLoss class""" import torch from ..utils import check_consistency from .loss_interface import LossInterface + class LpLoss(LossInterface): r""" The Lp loss implementation class. 
Creates a criterion that measures @@ -75,4 +76,4 @@ class LpLoss(LossInterface): loss = torch.linalg.norm((input - target), ord=self.p, dim=-1) if self.relative: loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1) - return self._reduction(loss) \ No newline at end of file + return self._reduction(loss) diff --git a/pina/loss/power_loss.py b/pina/loss/power_loss.py index 4f3fc65..09bf94a 100644 --- a/pina/loss/power_loss.py +++ b/pina/loss/power_loss.py @@ -1,4 +1,4 @@ -""" Module for PowerLoss class """ +"""Module for PowerLoss class""" import torch @@ -76,4 +76,4 @@ class PowerLoss(LossInterface): loss = torch.abs((input - target)).pow(self.p).mean(-1) if self.relative: loss = loss / torch.abs(input).pow(self.p).mean(-1) - return self._reduction(loss) \ No newline at end of file + return self._reduction(loss) diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index 4248627..3273dea 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -1,4 +1,4 @@ -""" Module for Loss Interface """ +"""Module for Loss Interface""" from .weighting_interface import WeightingInterface from ..utils import check_consistency @@ -8,10 +8,12 @@ class _NoWeighting(WeightingInterface): def aggregate(self, losses): return sum(losses.values()) + class ScalarWeighting(WeightingInterface): """ TODO """ + def __init__(self, weights): super().__init__() check_consistency([weights], (float, dict, int)) @@ -31,6 +33,6 @@ class ScalarWeighting(WeightingInterface): :rtype: torch.Tensor """ return sum( - self.weights.get(condition, self.default_value_weights) * loss for - condition, loss in losses.items() + self.weights.get(condition, self.default_value_weights) * loss + for condition, loss in losses.items() ) diff --git a/pina/loss/weighting_interface.py b/pina/loss/weighting_interface.py index 982abf5..d8ce4b5 100644 --- a/pina/loss/weighting_interface.py +++ b/pina/loss/weighting_interface.py @@ -1,4 +1,4 @@ -""" Module for Loss Interface """ +"""Module for Loss Interface""" from abc import ABCMeta, abstractmethod diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 0b1f2df..502e15d 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -10,7 +10,7 @@ __all__ = [ "AveragingNeuralOperator", "LowRankNeuralOperator", "Spline", - "GraphNeuralOperator" + "GraphNeuralOperator", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -21,4 +21,4 @@ from .kernel_neural_operator import KernelNeuralOperator from .average_neural_operator import AveragingNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator from .spline import Spline -from .graph_neural_operator import GraphNeuralOperator \ No newline at end of file +from .graph_neural_operator import GraphNeuralOperator diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py index fd5a266..64fb150 100644 --- a/pina/model/block/__init__.py +++ b/pina/model/block/__init__.py @@ -15,7 +15,7 @@ __all__ = [ "AVNOBlock", "LowRankBlock", "RBFBlock", - "GNOBlock" + "GNOBlock", ] from .convolution_2d import ContinuousConvBlock diff --git a/pina/model/block/average_neural_operator_block.py b/pina/model/block/average_neural_operator_block.py index 256dbbb..fd682a5 100644 --- a/pina/model/block/average_neural_operator_block.py +++ b/pina/model/block/average_neural_operator_block.py @@ -1,4 +1,4 @@ -""" Module for Averaging Neural Operator Layer class. 
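What ScalarWeighting.aggregate above reduces to, as a plain-torch sketch with hypothetical per-condition losses:

import torch

weights = {"laplace_D": 2.0}  # conditions not listed fall back to the default weight
default_value_weights = 1.0
losses = {"laplace_D": torch.tensor(0.3), "nil_g1": torch.tensor(0.1)}
total = sum(weights.get(c, default_value_weights) * l for c, l in losses.items())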
""" +"""Module for Averaging Neural Operator Layer class.""" from torch import nn, mean from ...utils import check_consistency diff --git a/pina/model/block/embedding.py b/pina/model/block/embedding.py index 4248136..77e340d 100644 --- a/pina/model/block/embedding.py +++ b/pina/model/block/embedding.py @@ -1,4 +1,4 @@ -""" Embedding modulus. """ +"""Embedding modulus.""" import torch from pina.utils import check_consistency diff --git a/pina/model/block/gno_block.py b/pina/model/block/gno_block.py index 34929fe..f391324 100644 --- a/pina/model/block/gno_block.py +++ b/pina/model/block/gno_block.py @@ -8,14 +8,14 @@ class GNOBlock(MessagePassing): """ def __init__( - self, - width, - edges_features, - n_layers=2, - layers=None, - inner_size=None, - internal_func=None, - external_func=None + self, + width, + edges_features, + n_layers=2, + layers=None, + inner_size=None, + internal_func=None, + external_func=None, ): """ Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric. @@ -28,16 +28,19 @@ class GNOBlock(MessagePassing): :type n_layers: int """ from pina.model import FeedForward - super(GNOBlock, self).__init__(aggr='mean') + + super(GNOBlock, self).__init__(aggr="mean") self.width = width if layers is None and inner_size is None: inner_size = width - self.dense = FeedForward(input_dimensions=edges_features, - output_dimensions=width ** 2, - n_layers=n_layers, - layers=layers, - inner_size=inner_size, - func=internal_func) + self.dense = FeedForward( + input_dimensions=edges_features, + output_dimensions=width**2, + n_layers=n_layers, + layers=layers, + inner_size=inner_size, + func=internal_func, + ) self.W = torch.nn.Linear(width, width) self.func = external_func() @@ -53,7 +56,7 @@ class GNOBlock(MessagePassing): :rtype: torch.Tensor """ x = self.dense(edge_attr).view(-1, self.width, self.width) - return torch.einsum('bij,bj->bi', x, x_j) + return torch.einsum("bij,bj->bi", x, x_j) def update(self, aggr_out, x): """ @@ -82,6 +85,4 @@ class GNOBlock(MessagePassing): :return: Output of a single iteration over the Graph Integral Layer. :rtype: torch.Tensor """ - return self.func( - self.propagate(edge_index, x=x, edge_attr=edge_attr) - ) + return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr)) diff --git a/pina/model/block/low_rank_block.py b/pina/model/block/low_rank_block.py index c36d279..dfb6864 100644 --- a/pina/model/block/low_rank_block.py +++ b/pina/model/block/low_rank_block.py @@ -1,4 +1,4 @@ -""" Module for Averaging Neural Operator Layer class. """ +"""Module for Averaging Neural Operator Layer class.""" import torch diff --git a/pina/model/graph_neural_operator.py b/pina/model/graph_neural_operator.py index 1932953..0e3a6d8 100644 --- a/pina/model/graph_neural_operator.py +++ b/pina/model/graph_neural_operator.py @@ -10,16 +10,16 @@ class GraphNeuralKernel(torch.nn.Module): """ def __init__( - self, - width, - edge_features, - n_layers=2, - internal_n_layers=0, - internal_layers=None, - inner_size=None, - internal_func=None, - external_func=None, - shared_weights=False + self, + width, + edge_features, + n_layers=2, + internal_n_layers=0, + internal_layers=None, + inner_size=None, + internal_func=None, + external_func=None, + shared_weights=False, ): """ The Graph Neural Kernel constructor. 
@@ -53,21 +53,24 @@ class GraphNeuralKernel(torch.nn.Module): layers=internal_layers, inner_size=inner_size, internal_func=internal_func, - external_func=external_func) + external_func=external_func, + ) self.n_layers = n_layers self.forward = self.forward_shared else: self.layers = torch.nn.ModuleList( - [GNOBlock( - width=width, - edges_features=edge_features, - n_layers=internal_n_layers, - layers=internal_layers, - inner_size=inner_size, - internal_func=internal_func, - external_func=external_func - ) - for _ in range(n_layers)] + [ + GNOBlock( + width=width, + edges_features=edge_features, + n_layers=internal_n_layers, + layers=internal_layers, + inner_size=inner_size, + internal_func=internal_func, + external_func=external_func, + ) + for _ in range(n_layers) + ] ) def forward(self, x, edge_index, edge_attr): @@ -107,17 +110,17 @@ class GraphNeuralOperator(KernelNeuralOperator): """ def __init__( - self, - lifting_operator, - projection_operator, - edge_features, - n_layers=10, - internal_n_layers=0, - inner_size=None, - internal_layers=None, - internal_func=None, - external_func=None, - shared_weights=True + self, + lifting_operator, + projection_operator, + edge_features, + n_layers=10, + internal_n_layers=0, + inner_size=None, + internal_layers=None, + internal_func=None, + external_func=None, + shared_weights=True, ): """ The Graph Neural Operator constructor. @@ -158,9 +161,9 @@ class GraphNeuralOperator(KernelNeuralOperator): external_func=external_func, internal_func=internal_func, n_layers=n_layers, - shared_weights=shared_weights + shared_weights=shared_weights, ), - projection_operator=projection_operator + projection_operator=projection_operator, ) def forward(self, x): diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 15e0851..dcf63dc 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -8,6 +8,7 @@ from ...utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - f"'pina.model.layers' is deprecated and will be removed " - f"in future versions. Please use 'pina.model.block' instead.", - DeprecationWarning) \ No newline at end of file + f"'pina.model.layers' is deprecated and will be removed " + f"in future versions. Please use 'pina.model.block' instead.", + DeprecationWarning, +) diff --git a/pina/operator.py b/pina/operator.py index a4388f4..85ebf9d 100644 --- a/pina/operator.py +++ b/pina/operator.py @@ -5,6 +5,7 @@ All operator take as input a tensor onto which computing the operator, a tensor to which computing the operator, the name of the output variables to calculate the operator for (in case of multidimensional functions), and the variables name on which the operator is calculated. 
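A minimal usage sketch for these differential operators, following the call signatures used in the problem zoo later in this patch; in the solvers, u would come from a model forward pass rather than the toy expression below:

import torch
from pina import LabelTensor
from pina.operator import grad, laplacian

pts = LabelTensor(torch.rand(10, 2), ["x", "y"]).requires_grad_()
u = LabelTensor(pts.extract("x") ** 2 + pts.extract("y") ** 2, ["u"])
du = grad(u, pts, d=["x", "y"])                           # du/dx and du/dy
lap = laplacian(u, pts, components=["u"], d=["x", "y"])   # sum of second derivatives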
""" + import torch from pina.label_tensor import LabelTensor @@ -56,9 +57,9 @@ def grad(output_, input_, components=None, d=None): gradients = torch.autograd.grad( output_, input_, - grad_outputs=torch.ones(output_.size(), - dtype=output_.dtype, - device=output_.device), + grad_outputs=torch.ones( + output_.size(), dtype=output_.dtype, device=output_.device + ), create_graph=True, retain_graph=True, allow_unused=True, @@ -83,8 +84,9 @@ def grad(output_, input_, components=None, d=None): raise RuntimeError gradients = grad_scalar_output(output_, input_, d) - elif output_.shape[output_.ndim - - 1] >= 2: # vector output ############################## + elif ( + output_.shape[output_.ndim - 1] >= 2 + ): # vector output ############################## tensor_to_cat = [] for i, c in enumerate(components): c_output = output_.extract([c]) @@ -253,8 +255,11 @@ def advection(output_, input_, velocity_field, components=None, d=None): if components is None: components = output_.labels - tmp = (grad(output_, input_, components, d).reshape(-1, len(components), - len(d)).transpose(0, 1)) + tmp = ( + grad(output_, input_, components, d) + .reshape(-1, len(components), len(d)) + .transpose(0, 1) + ) tmp *= output_.extract(velocity_field) return tmp.sum(dim=2).T diff --git a/pina/operators.py b/pina/operators.py index dadebb1..5e3e838 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -8,6 +8,7 @@ from .utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - f"'pina.operators' is deprecated and will be removed " - f"in future versions. Please use 'pina.operator' instead.", - DeprecationWarning) \ No newline at end of file + f"'pina.operators' is deprecated and will be removed " + f"in future versions. 
Please use 'pina.operator' instead.", + DeprecationWarning, +) diff --git a/pina/optim/__init__.py b/pina/optim/__init__.py index 699706c..631134a 100644 --- a/pina/optim/__init__.py +++ b/pina/optim/__init__.py @@ -8,4 +8,4 @@ __all__ = [ from .optimizer_interface import Optimizer from .torch_optimizer import TorchOptimizer from .scheduler_interface import Scheduler -from .torch_scheduler import TorchScheduler \ No newline at end of file +from .torch_scheduler import TorchScheduler diff --git a/pina/optim/optimizer_interface.py b/pina/optim/optimizer_interface.py index e20e747..0d197ea 100644 --- a/pina/optim/optimizer_interface.py +++ b/pina/optim/optimizer_interface.py @@ -1,4 +1,4 @@ -""" Module for PINA Optimizer """ +"""Module for PINA Optimizer""" from abc import ABCMeta, abstractmethod @@ -12,4 +12,4 @@ class Optimizer(metaclass=ABCMeta): # TODO improve interface @abstractmethod def hook(self): - pass \ No newline at end of file + pass diff --git a/pina/optim/scheduler_interface.py b/pina/optim/scheduler_interface.py index 51fbbd4..1cae521 100644 --- a/pina/optim/scheduler_interface.py +++ b/pina/optim/scheduler_interface.py @@ -1,9 +1,9 @@ -""" Module for PINA Optimizer """ +"""Module for PINA Optimizer""" from abc import ABCMeta, abstractmethod -class Scheduler(metaclass=ABCMeta): # TODO improve interface +class Scheduler(metaclass=ABCMeta): # TODO improve interface @property @abstractmethod @@ -12,4 +12,4 @@ class Scheduler(metaclass=ABCMeta): # TODO improve interface @abstractmethod def hook(self): - pass \ No newline at end of file + pass diff --git a/pina/optim/torch_optimizer.py b/pina/optim/torch_optimizer.py index 4be7b47..02b8920 100644 --- a/pina/optim/torch_optimizer.py +++ b/pina/optim/torch_optimizer.py @@ -1,4 +1,4 @@ -""" Module for PINA Torch Optimizer """ +"""Module for PINA Torch Optimizer""" import torch @@ -16,8 +16,10 @@ class TorchOptimizer(Optimizer): self._optimizer_instance = None def hook(self, parameters): - self._optimizer_instance = self.optimizer_class(parameters, - **self.kwargs) + self._optimizer_instance = self.optimizer_class( + parameters, **self.kwargs + ) + @property def instance(self): """ diff --git a/pina/optim/torch_scheduler.py b/pina/optim/torch_scheduler.py index 2700a08..bf8daec 100644 --- a/pina/optim/torch_scheduler.py +++ b/pina/optim/torch_scheduler.py @@ -1,11 +1,13 @@ -""" Module for PINA Torch Optimizer """ +"""Module for PINA Torch Optimizer""" import torch + try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 except ImportError: from torch.optim.lr_scheduler import ( - _LRScheduler as LRScheduler, ) # torch < 2.0 + _LRScheduler as LRScheduler, + ) # torch < 2.0 from ..utils import check_consistency from .optimizer_interface import Optimizer @@ -24,7 +26,8 @@ class TorchScheduler(Scheduler): def hook(self, optimizer): check_consistency(optimizer, Optimizer) self._scheduler_instance = self.scheduler_class( - optimizer.instance, **self.kwargs) + optimizer.instance, **self.kwargs + ) @property def instance(self): diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 2c2c8de..9cdef60 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -1,4 +1,4 @@ -""" Module for AbstractProblem class """ +"""Module for AbstractProblem class""" from abc import ABCMeta, abstractmethod from ..utils import check_consistency @@ -60,7 +60,7 @@ class AbstractProblem(metaclass=ABCMeta): elif hasattr(cond, "domain"): to_return[cond_name] = 
self._discretised_domains[cond.domain] return to_return - + @property def discretised_domains(self): return self._discretised_domains @@ -138,11 +138,9 @@ class AbstractProblem(metaclass=ABCMeta): """ return self.conditions - def discretise_domain(self, - n=None, - mode="random", - domains="all", - sample_rules=None): + def discretise_domain( + self, n=None, mode="random", domains="all", sample_rules=None + ): """ Generate a set of points to span the `Location` of all the conditions of the problem. @@ -193,9 +191,7 @@ class AbstractProblem(metaclass=ABCMeta): "You can't specify both n and sample_rules at the same time." ) elif n is None and sample_rules is None: - raise RuntimeError( - "You have to specify either n or sample_rules." - ) + raise RuntimeError("You have to specify either n or sample_rules.") def _apply_default_discretization(self, n, mode, domains): for domain in domains: @@ -213,15 +209,17 @@ class AbstractProblem(metaclass=ABCMeta): if not isinstance(self.domains[domain], CartesianDomain): raise RuntimeError( "Custom discretisation can be applied only on Cartesian " - "domains") + "domains" + ) discretised_tensor = [] for var, rules in sample_rules.items(): - n, mode = rules['n'], rules['mode'] + n, mode = rules["n"], rules["mode"] points = self.domains[domain].sample(n, mode, var) discretised_tensor.append(points) self.discretised_domains[domain] = merge_tensors( - discretised_tensor).sort_labels() + discretised_tensor + ).sort_labels() def add_points(self, new_points_dict): """ @@ -232,4 +230,5 @@ class AbstractProblem(metaclass=ABCMeta): """ for k, v in new_points_dict.items(): self.discretised_domains[k] = LabelTensor.vstack( - [self.discretised_domains[k], v]) + [self.discretised_domains[k], v] + ) diff --git a/pina/problem/inverse_problem.py b/pina/problem/inverse_problem.py index 09c5981..7451e2b 100644 --- a/pina/problem/inverse_problem.py +++ b/pina/problem/inverse_problem.py @@ -1,4 +1,5 @@ """Module for the ParametricProblem class""" + import torch from abc import abstractmethod from .abstract_problem import AbstractProblem @@ -51,12 +52,9 @@ class InverseProblem(AbstractProblem): for i, var in enumerate(self.unknown_variables): range_var = self.unknown_parameter_domain.range_[var] tensor_var = ( - torch.rand(1, requires_grad=True) * range_var[1] - + range_var[0] - ) - self.unknown_parameters[var] = torch.nn.Parameter( - tensor_var + torch.rand(1, requires_grad=True) * range_var[1] + range_var[0] ) + self.unknown_parameters[var] = torch.nn.Parameter(tensor_var) @abstractmethod def unknown_parameter_domain(self): diff --git a/pina/problem/zoo/__init__.py b/pina/problem/zoo/__init__.py index 15e03e9..b10c0fb 100644 --- a/pina/problem/zoo/__init__.py +++ b/pina/problem/zoo/__init__.py @@ -1,9 +1,9 @@ __all__ = [ - 'Poisson2DSquareProblem', - 'SupervisedProblem', - 'InversePoisson2DSquareProblem', - 'DiffusionReactionProblem', - 'InverseDiffusionReactionProblem' + "Poisson2DSquareProblem", + "SupervisedProblem", + "InversePoisson2DSquareProblem", + "DiffusionReactionProblem", + "InverseDiffusionReactionProblem", ] from .poisson_2d_square import Poisson2DSquareProblem diff --git a/pina/problem/zoo/diffusion_reaction.py b/pina/problem/zoo/diffusion_reaction.py index 8bf6284..e7bc6c2 100644 --- a/pina/problem/zoo/diffusion_reaction.py +++ b/pina/problem/zoo/diffusion_reaction.py @@ -1,4 +1,4 @@ -""" Definition of the diffusion-reaction problem.""" +"""Definition of the diffusion-reaction problem.""" import torch from pina import Condition @@ -7,17 +7,22 @@ from 
pina.equation.equation import Equation from pina.domain import CartesianDomain from pina.operator import grad + def diffusion_reaction(input_, output_): """ Implementation of the diffusion-reaction equation. """ - x = input_.extract('x') - t = input_.extract('t') - u_t = grad(output_, input_, d='t') - u_x = grad(output_, input_, d='x') - u_xx = grad(u_x, input_, d='x') - r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3) * torch.sin(3*x) + - (15/4) * torch.sin(4*x) + (63/8) * torch.sin(8*x)) + x = input_.extract("x") + t = input_.extract("t") + u_t = grad(output_, input_, d="t") + u_x = grad(output_, input_, d="x") + u_xx = grad(u_x, input_, d="x") + r = torch.exp(-t) * ( + 1.5 * torch.sin(2 * x) + + (8 / 3) * torch.sin(3 * x) + + (15 / 4) * torch.sin(4 * x) + + (63 / 8) * torch.sin(8 * x) + ) return u_t - u_xx - r @@ -26,20 +31,25 @@ class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem): Implementation of the diffusion-reaction problem on the spatial interval [-pi, pi] and temporal interval [0,1]. """ - output_variables = ['u'] - spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]}) - temporal_domain = CartesianDomain({'t': [0, 1]}) + + output_variables = ["u"] + spatial_domain = CartesianDomain({"x": [-torch.pi, torch.pi]}) + temporal_domain = CartesianDomain({"t": [0, 1]}) conditions = { - 'D': Condition( - domain=CartesianDomain({'x': [-torch.pi, torch.pi], 't': [0, 1]}), - equation=Equation(diffusion_reaction)) + "D": Condition( + domain=CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}), + equation=Equation(diffusion_reaction), + ) } def _solution(self, pts): - t = pts.extract('t') - x = pts.extract('x') + t = pts.extract("t") + x = pts.extract("x") return torch.exp(-t) * ( - torch.sin(x) + (1/2)*torch.sin(2*x) + (1/3)*torch.sin(3*x) + - (1/4)*torch.sin(4*x) + (1/8)*torch.sin(8*x) + torch.sin(x) + + (1 / 2) * torch.sin(2 * x) + + (1 / 3) * torch.sin(3 * x) + + (1 / 4) * torch.sin(4 * x) + + (1 / 8) * torch.sin(8 * x) ) diff --git a/pina/problem/zoo/inverse_diffusion_reaction.py b/pina/problem/zoo/inverse_diffusion_reaction.py index 0758ce2..911f68e 100644 --- a/pina/problem/zoo/inverse_diffusion_reaction.py +++ b/pina/problem/zoo/inverse_diffusion_reaction.py @@ -1,4 +1,4 @@ -""" Definition of the diffusion-reaction problem.""" +"""Definition of the diffusion-reaction problem.""" import torch from pina import Condition, LabelTensor @@ -7,45 +7,57 @@ from pina.equation.equation import Equation from pina.domain import CartesianDomain from pina.operator import grad + def diffusion_reaction(input_, output_): """ Implementation of the diffusion-reaction equation. 
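A minimal sketch of instantiating and sampling the zoo problem defined above, using the discretise_domain signature from this patch (point counts and modes are hypothetical):

from pina.problem.zoo import DiffusionReactionProblem

problem = DiffusionReactionProblem()
problem.discretise_domain(n=100, mode="random", domains="all")

# per-variable rules are also supported on Cartesian domains
problem.discretise_domain(
    sample_rules={"x": {"n": 20, "mode": "grid"}, "t": {"n": 10, "mode": "grid"}},
    domains=["D"],
)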
""" - x = input_.extract('x') - t = input_.extract('t') - u_t = grad(output_, input_, d='t') - u_x = grad(output_, input_, d='x') - u_xx = grad(u_x, input_, d='x') - r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3) * torch.sin(3*x) + - (15/4) * torch.sin(4*x) + (63/8) * torch.sin(8*x)) + x = input_.extract("x") + t = input_.extract("t") + u_t = grad(output_, input_, d="t") + u_x = grad(output_, input_, d="x") + u_xx = grad(u_x, input_, d="x") + r = torch.exp(-t) * ( + 1.5 * torch.sin(2 * x) + + (8 / 3) * torch.sin(3 * x) + + (15 / 4) * torch.sin(4 * x) + + (63 / 8) * torch.sin(8 * x) + ) return u_t - u_xx - r -class InverseDiffusionReactionProblem(TimeDependentProblem, - SpatialProblem, - InverseProblem): + +class InverseDiffusionReactionProblem( + TimeDependentProblem, SpatialProblem, InverseProblem +): """ - Implementation of the diffusion-reaction inverse problem on the spatial - interval [-pi, pi] and temporal interval [0,1], with unknown parameters + Implementation of the diffusion-reaction inverse problem on the spatial + interval [-pi, pi] and temporal interval [0,1], with unknown parameters in the interval [-1,1]. """ - output_variables = ['u'] - spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]}) - temporal_domain = CartesianDomain({'t': [0, 1]}) - unknown_parameter_domain = CartesianDomain({'mu': [-1, 1]}) + + output_variables = ["u"] + spatial_domain = CartesianDomain({"x": [-torch.pi, torch.pi]}) + temporal_domain = CartesianDomain({"t": [0, 1]}) + unknown_parameter_domain = CartesianDomain({"mu": [-1, 1]}) conditions = { - 'D': Condition( - domain=CartesianDomain({'x': [-torch.pi, torch.pi], 't': [0, 1]}), - equation=Equation(diffusion_reaction)), - 'data' : Condition( - input_points=LabelTensor(torch.randn(10, 2), ['x', 't']), - output_points=LabelTensor(torch.randn(10, 1), ['u'])), + "D": Condition( + domain=CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}), + equation=Equation(diffusion_reaction), + ), + "data": Condition( + input_points=LabelTensor(torch.randn(10, 2), ["x", "t"]), + output_points=LabelTensor(torch.randn(10, 1), ["u"]), + ), } def _solution(self, pts): - t = pts.extract('t') - x = pts.extract('x') + t = pts.extract("t") + x = pts.extract("x") return torch.exp(-t) * ( - torch.sin(x) + (1/2)*torch.sin(2*x) + (1/3)*torch.sin(3*x) + - (1/4)*torch.sin(4*x) + (1/8)*torch.sin(8*x) + torch.sin(x) + + (1 / 2) * torch.sin(2 * x) + + (1 / 3) * torch.sin(3 * x) + + (1 / 4) * torch.sin(4 * x) + + (1 / 8) * torch.sin(8 * x) ) diff --git a/pina/problem/zoo/inverse_poisson_2d_square.py b/pina/problem/zoo/inverse_poisson_2d_square.py index 4e147ac..3a46334 100644 --- a/pina/problem/zoo/inverse_poisson_2d_square.py +++ b/pina/problem/zoo/inverse_poisson_2d_square.py @@ -1,4 +1,4 @@ -""" Definition of the inverse Poisson problem on a square domain.""" +"""Definition of the inverse Poisson problem on a square domain.""" import torch from pina import Condition, LabelTensor @@ -8,43 +8,49 @@ from pina.domain import CartesianDomain from pina.equation.equation import Equation from pina.equation.equation_factory import FixedValue + def laplace_equation(input_, output_, params_): """ Implementation of the laplace equation. 
""" - force_term = torch.exp(- 2*(input_.extract(['x']) - params_['mu1'])**2 - - 2*(input_.extract(['y']) - params_['mu2'])**2) - delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + force_term = torch.exp( + -2 * (input_.extract(["x"]) - params_["mu1"]) ** 2 + - 2 * (input_.extract(["y"]) - params_["mu2"]) ** 2 + ) + delta_u = laplacian(output_, input_, components=["u"], d=["x", "y"]) return delta_u - force_term + class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem): """ - Implementation of the inverse 2-dimensional Poisson problem + Implementation of the inverse 2-dimensional Poisson problem on a square domain, with parameter domain [-1, 1] x [-1, 1]. """ - output_variables = ['u'] + + output_variables = ["u"] x_min, x_max = -2, 2 y_min, y_max = -2, 2 - data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) - data_output = LabelTensor(torch.rand(10, 1), ['u']) - spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) - unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + data_input = LabelTensor(torch.rand(10, 2), ["x", "y"]) + data_output = LabelTensor(torch.rand(10, 1), ["u"]) + spatial_domain = CartesianDomain({"x": [x_min, x_max], "y": [y_min, y_max]}) + unknown_parameter_domain = CartesianDomain({"mu1": [-1, 1], "mu2": [-1, 1]}) domains = { - 'g1': CartesianDomain({'x': [x_min, x_max], 'y': y_max}), - 'g2': CartesianDomain({'x': [x_min, x_max], 'y': y_min}), - 'g3': CartesianDomain({'x': x_max, 'y': [y_min, y_max]}), - 'g4': CartesianDomain({'x': x_min, 'y': [y_min, y_max]}), - 'D': CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}), + "g1": CartesianDomain({"x": [x_min, x_max], "y": y_max}), + "g2": CartesianDomain({"x": [x_min, x_max], "y": y_min}), + "g3": CartesianDomain({"x": x_max, "y": [y_min, y_max]}), + "g4": CartesianDomain({"x": x_min, "y": [y_min, y_max]}), + "D": CartesianDomain({"x": [x_min, x_max], "y": [y_min, y_max]}), } conditions = { - 'nil_g1': Condition(domain='g1', equation=FixedValue(0.0)), - 'nil_g2': Condition(domain='g2', equation=FixedValue(0.0)), - 'nil_g3': Condition(domain='g3', equation=FixedValue(0.0)), - 'nil_g4': Condition(domain='g4', equation=FixedValue(0.0)), - 'laplace_D': Condition(domain='D', equation=Equation(laplace_equation)), - 'data': Condition( - input_points=data_input.extract(['x', 'y']), - output_points=data_output) + "nil_g1": Condition(domain="g1", equation=FixedValue(0.0)), + "nil_g2": Condition(domain="g2", equation=FixedValue(0.0)), + "nil_g3": Condition(domain="g3", equation=FixedValue(0.0)), + "nil_g4": Condition(domain="g4", equation=FixedValue(0.0)), + "laplace_D": Condition(domain="D", equation=Equation(laplace_equation)), + "data": Condition( + input_points=data_input.extract(["x", "y"]), + output_points=data_output, + ), } diff --git a/pina/problem/zoo/poisson_2d_square.py b/pina/problem/zoo/poisson_2d_square.py index 1e161ba..89d9ee3 100644 --- a/pina/problem/zoo/poisson_2d_square.py +++ b/pina/problem/zoo/poisson_2d_square.py @@ -1,4 +1,4 @@ -""" Definition of the Poisson problem on a square domain.""" +"""Definition of the Poisson problem on a square domain.""" from pina.problem import SpatialProblem from pina.operator import laplacian @@ -8,41 +8,47 @@ from pina.equation.equation import Equation from pina.equation.equation_factory import FixedValue import torch + def laplace_equation(input_, output_): """ Implementation of the laplace equation. 
""" - force_term = (torch.sin(input_.extract(['x']) * torch.pi) * - torch.sin(input_.extract(['y']) * torch.pi)) - delta_u = laplacian(output_.extract(['u']), input_) + force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( + input_.extract(["y"]) * torch.pi + ) + delta_u = laplacian(output_.extract(["u"]), input_) return delta_u - force_term + my_laplace = Equation(laplace_equation) + class Poisson2DSquareProblem(SpatialProblem): """ Implementation of the 2-dimensional Poisson problem on a square domain. """ - output_variables = ['u'] - spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + output_variables = ["u"] + spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) domains = { - 'D': CartesianDomain({'x': [0, 1], 'y': [0, 1]}), - 'g1': CartesianDomain({'x': [0, 1], 'y': 1}), - 'g2': CartesianDomain({'x': [0, 1], 'y': 0}), - 'g3': CartesianDomain({'x': 1, 'y': [0, 1]}), - 'g4': CartesianDomain({'x': 0, 'y': [0, 1]}), + "D": CartesianDomain({"x": [0, 1], "y": [0, 1]}), + "g1": CartesianDomain({"x": [0, 1], "y": 1}), + "g2": CartesianDomain({"x": [0, 1], "y": 0}), + "g3": CartesianDomain({"x": 1, "y": [0, 1]}), + "g4": CartesianDomain({"x": 0, "y": [0, 1]}), } conditions = { - 'nil_g1': Condition(domain='g1', equation=FixedValue(0.0)), - 'nil_g2': Condition(domain='g2', equation=FixedValue(0.0)), - 'nil_g3': Condition(domain='g3', equation=FixedValue(0.0)), - 'nil_g4': Condition(domain='g4', equation=FixedValue(0.0)), - 'laplace_D': Condition(domain='D', equation=my_laplace), + "nil_g1": Condition(domain="g1", equation=FixedValue(0.0)), + "nil_g2": Condition(domain="g2", equation=FixedValue(0.0)), + "nil_g3": Condition(domain="g3", equation=FixedValue(0.0)), + "nil_g4": Condition(domain="g4", equation=FixedValue(0.0)), + "laplace_D": Condition(domain="D", equation=my_laplace), } def poisson_sol(self, pts): - return -(torch.sin(pts.extract(['x']) * torch.pi) * - torch.sin(pts.extract(['y']) * torch.pi)) - + return -( + torch.sin(pts.extract(["x"]) * torch.pi) + * torch.sin(pts.extract(["y"]) * torch.pi) + ) diff --git a/pina/problem/zoo/supervised_problem.py b/pina/problem/zoo/supervised_problem.py index 6acac7a..ef04062 100644 --- a/pina/problem/zoo/supervised_problem.py +++ b/pina/problem/zoo/supervised_problem.py @@ -2,6 +2,7 @@ from pina.problem import AbstractProblem from pina import Condition from pina import Graph + class SupervisedProblem(AbstractProblem): """ A problem definition for supervised learning in PINA. 
@@ -15,6 +16,7 @@ class SupervisedProblem(AbstractProblem): >>> output_data = torch.rand((100, 10)) >>> problem = SupervisedProblem(input_data, output_data) """ + conditions = dict() output_variables = None @@ -29,9 +31,7 @@ class SupervisedProblem(AbstractProblem): """ if isinstance(input_, Graph): input_ = input_.data - self.conditions['data'] = Condition( - input_points=input_, - output_points = output_ + self.conditions["data"] = Condition( + input_points=input_, output_points=output_ ) super().__init__() - \ No newline at end of file diff --git a/pina/solver/garom.py b/pina/solver/garom.py index 685da5e..930c144 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -1,4 +1,4 @@ -""" Module for GAROM """ +"""Module for GAROM""" import torch @@ -86,13 +86,14 @@ class GAROM(MultiSolverInterface): scheduler_generator, scheduler_discriminator, ], - use_lt=False + use_lt=False, ) # check consistency - check_consistency(loss, (LossInterface, _Loss, torch.nn.Module), - subclass=False) - self._loss = loss + check_consistency( + loss, (LossInterface, _Loss, torch.nn.Module), subclass=False + ) + self._loss = loss # set automatic optimization for GANs self.automatic_optimization = False @@ -152,9 +153,7 @@ class GAROM(MultiSolverInterface): # generator loss r_loss = self._loss(snapshots, generated_snapshots) - d_fake = self.discriminator( - [generated_snapshots, parameters] - ) + d_fake = self.discriminator([generated_snapshots, parameters]) g_loss = ( self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss ) @@ -180,8 +179,7 @@ class GAROM(MultiSolverInterface): """ # increase by one the counter of optimization to save loggers ( - self.trainer.fit_loop.epoch_loop.manual_optimization - .optim_step_progress.total.completed + self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed ) += 1 return super().on_train_batch_end(outputs, batch, batch_idx) @@ -198,9 +196,7 @@ class GAROM(MultiSolverInterface): # Discriminator pass d_real = self.discriminator([snapshots, parameters]) - d_fake = self.discriminator( - [generated_snapshots, parameters] - ) + d_fake = self.discriminator([generated_snapshots, parameters]) # evaluate loss d_loss_real = self._loss(d_real, snapshots) @@ -236,7 +232,10 @@ class GAROM(MultiSolverInterface): """ condition_loss = {} for condition_name, points in batch: - parameters, snapshots = points['input_points'], points['output_points'] + parameters, snapshots = ( + points["input_points"], + points["output_points"], + ) d_loss_real, d_loss_fake, d_loss = self._train_discriminator( parameters, snapshots ) @@ -245,51 +244,53 @@ class GAROM(MultiSolverInterface): condition_loss[condition_name] = r_loss # some extra logging - self.store_log( - "d_loss", - float(d_loss), - self.get_batch_size(batch) - ) - self.store_log( - "g_loss", - float(g_loss), - self.get_batch_size(batch) - ) + self.store_log("d_loss", float(d_loss), self.get_batch_size(batch)) + self.store_log("g_loss", float(g_loss), self.get_batch_size(batch)) self.store_log( "stability_metric", float(d_loss_real + torch.abs(diff)), - self.get_batch_size(batch) + self.get_batch_size(batch), ) return condition_loss def validation_step(self, batch): condition_loss = {} for condition_name, points in batch: - parameters, snapshots = points['input_points'], points['output_points'] + parameters, snapshots = ( + points["input_points"], + points["output_points"], + ) snapshots_gen = self.generator(parameters) - condition_loss[condition_name] = self._loss(snapshots, snapshots_gen) 
+ condition_loss[condition_name] = self._loss( + snapshots, snapshots_gen + ) loss = self.weighting.aggregate(condition_loss) - self.store_log('val_loss', loss, self.get_batch_size(batch)) + self.store_log("val_loss", loss, self.get_batch_size(batch)) return loss - + def test_step(self, batch): condition_loss = {} for condition_name, points in batch: - parameters, snapshots = points['input_points'], points['output_points'] + parameters, snapshots = ( + points["input_points"], + points["output_points"], + ) snapshots_gen = self.generator(parameters) - condition_loss[condition_name] = self._loss(snapshots, snapshots_gen) + condition_loss[condition_name] = self._loss( + snapshots, snapshots_gen + ) loss = self.weighting.aggregate(condition_loss) - self.store_log('test_loss', loss, self.get_batch_size(batch)) + self.store_log("test_loss", loss, self.get_batch_size(batch)) return loss - + @property def generator(self): return self.models[0] - + @property def discriminator(self): return self.models[1] - + @property def optimizer_generator(self): return self.optimizers[0].instance diff --git a/pina/solver/physic_informed_solver/causal_pinn.py b/pina/solver/physic_informed_solver/causal_pinn.py index eab03db..4984722 100644 --- a/pina/solver/physic_informed_solver/causal_pinn.py +++ b/pina/solver/physic_informed_solver/causal_pinn.py @@ -1,4 +1,4 @@ -""" Module for Causal PINN. """ +"""Module for Causal PINN.""" import torch @@ -67,14 +67,16 @@ class CausalPINN(PINN): :class:`~pina.problem.timedep_problem.TimeDependentProblem` class. """ - def __init__(self, - problem, - model, - optimizer=None, - scheduler=None, - weighting=None, - loss=None, - eps=100): + def __init__( + self, + problem, + model, + optimizer=None, + scheduler=None, + weighting=None, + loss=None, + eps=100, + ): """ :param torch.nn.Module model: The neural network model to use. :param AbstractProblem problem: The formulation of the problem. @@ -88,12 +90,14 @@ class CausalPINN(PINN): default `None`. :param float eps: The exponential decay parameter; default `100`. """ - super().__init__(model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - loss=loss) + super().__init__( + model=model, + problem=problem, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + loss=loss, + ) # checking consistency check_consistency(eps, (int, float)) diff --git a/pina/solver/physic_informed_solver/competitive_pinn.py b/pina/solver/physic_informed_solver/competitive_pinn.py index 2485ea8..0225ea6 100644 --- a/pina/solver/physic_informed_solver/competitive_pinn.py +++ b/pina/solver/physic_informed_solver/competitive_pinn.py @@ -1,4 +1,4 @@ -""" Module for Competitive PINN. """ +"""Module for Competitive PINN.""" import torch import copy @@ -55,16 +55,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): ``extra_feature``. """ - def __init__(self, - problem, - model, - discriminator=None, - optimizer_model=None, - optimizer_discriminator=None, - scheduler_model=None, - scheduler_discriminator=None, - weighting=None, - loss=None): + def __init__( + self, + problem, + model, + discriminator=None, + optimizer_model=None, + optimizer_discriminator=None, + scheduler_model=None, + scheduler_discriminator=None, + weighting=None, + loss=None, + ): """ :param AbstractProblem problem: The formulation of the problem. 
:param torch.nn.Module model: The neural network model to use @@ -72,13 +74,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): :param torch.nn.Module discriminator: The neural network model to use for the discriminator. If ``None``, the discriminator network will have the same architecture as the model network. - :param torch.optim.Optimizer optimizer_model: The neural network + :param torch.optim.Optimizer optimizer_model: The neural network optimizer to use for the model network; default `None`. :param torch.optim.Optimizer optimizer_discriminator: The neural network optimizer to use for the discriminator network; default `None`. - :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler + :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler for the model; default `None`. - :param torch.optim.LRScheduler scheduler_discriminator: Learning rate + :param torch.optim.LRScheduler scheduler_discriminator: Learning rate scheduler for the discriminator; default `None`. :param WeightingInterface weighting: The weighting schema to use; default `None`. @@ -88,12 +90,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): if discriminator is None: discriminator = copy.deepcopy(model) - super().__init__(models=[model, discriminator], - problem=problem, - optimizers=[optimizer_model, optimizer_discriminator], - schedulers=[scheduler_model, scheduler_discriminator], - weighting=weighting, - loss=loss) + super().__init__( + models=[model, discriminator], + problem=problem, + optimizers=[optimizer_model, optimizer_discriminator], + schedulers=[scheduler_model, scheduler_discriminator], + weighting=weighting, + loss=loss, + ) # Set automatic optimization to False self.automatic_optimization = False @@ -158,7 +162,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): def loss_data(self, input_pts, output_pts): """ - The data loss for the CompetitivePINN solver. It computes the loss + The data loss for the CompetitivePINN solver. It computes the loss between the network output against the true solution. :param LabelTensor input_tensor: The input to the neural networks. @@ -167,7 +171,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): :return: The computed data loss. 
:rtype: torch.Tensor """ - loss_val = (super().loss_data(input_pts, output_pts)) + loss_val = super().loss_data(input_pts, output_pts) # prepare for optimizer step called in training step loss_val.backward() return loss_val @@ -195,10 +199,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): self.scheduler_model.hook(self.optimizer_model) self.scheduler_discriminator.hook(self.optimizer_discriminator) return ( - [self.optimizer_model.instance, - self.optimizer_discriminator.instance], - [self.scheduler_model.instance, - self.scheduler_discriminator.instance] + [ + self.optimizer_model.instance, + self.optimizer_discriminator.instance, + ], + [ + self.scheduler_model.instance, + self.scheduler_discriminator.instance, + ], ) def on_train_batch_end(self, outputs, batch, batch_idx): @@ -216,8 +224,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): """ # increase by one the counter of optimization to save loggers ( - self.trainer.fit_loop.epoch_loop.manual_optimization - .optim_step_progress.total.completed + self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed ) += 1 return super().on_train_batch_end(outputs, batch, batch_idx) diff --git a/pina/solver/physic_informed_solver/gradient_pinn.py b/pina/solver/physic_informed_solver/gradient_pinn.py index de439b5..cad5bce 100644 --- a/pina/solver/physic_informed_solver/gradient_pinn.py +++ b/pina/solver/physic_informed_solver/gradient_pinn.py @@ -1,4 +1,4 @@ -""" Module for Gradient PINN. """ +"""Module for Gradient PINN.""" import torch @@ -59,18 +59,20 @@ class GradientPINN(PINN): class. """ - def __init__(self, - problem, - model, - optimizer=None, - scheduler=None, - weighting=None, - loss=None): + def __init__( + self, + problem, + model, + optimizer=None, + scheduler=None, + weighting=None, + loss=None, + ): """ :param torch.nn.Module model: The neural network model to use. :param AbstractProblem problem: The formulation of the problem. It must inherit from at least - :class:`~pina.problem.spatial_problem.SpatialProblem` to compute + :class:`~pina.problem.spatial_problem.SpatialProblem` to compute the gradient of the loss. :param torch.optim.Optimizer optimizer: The neural network optimizer to use; default `None`. @@ -81,12 +83,14 @@ class GradientPINN(PINN): :param torch.nn.Module loss: The loss function to be minimized; default `None`. """ - super().__init__(model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - loss=loss) + super().__init__( + model=model, + problem=problem, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + loss=loss, + ) if not isinstance(self.problem, SpatialProblem): raise ValueError( diff --git a/pina/solver/physic_informed_solver/pinn.py b/pina/solver/physic_informed_solver/pinn.py index 83a4a1e..d3c2af6 100644 --- a/pina/solver/physic_informed_solver/pinn.py +++ b/pina/solver/physic_informed_solver/pinn.py @@ -1,4 +1,4 @@ -""" Module for Physics Informed Neural Network. """ +"""Module for Physics Informed Neural Network.""" import torch @@ -48,13 +48,15 @@ class PINN(PINNInterface, SingleSolverInterface): DOI: `10.1038 `_. """ - def __init__(self, - problem, - model, - optimizer=None, - scheduler=None, - weighting=None, - loss=None): + def __init__( + self, + problem, + model, + optimizer=None, + scheduler=None, + weighting=None, + loss=None, + ): """ :param torch.nn.Module model: The neural network model to use. :param AbstractProblem problem: The formulation of the problem. 
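A minimal end-to-end sketch for the PINN solver configured above; the model size, import paths, and Trainer arguments are assumptions based on the rest of this patch:

from pina import Trainer
from pina.model import FeedForward
from pina.solver.physic_informed_solver import PINN
from pina.problem.zoo import Poisson2DSquareProblem

problem = Poisson2DSquareProblem()
problem.discretise_domain(n=50, mode="random")
model = FeedForward(input_dimensions=2, output_dimensions=1)  # 2 inputs (x, y) -> 1 output (u)
solver = PINN(problem=problem, model=model)
trainer = Trainer(solver, max_epochs=5, batch_size=None)      # batch_size=None: full-batch training
trainer.train()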
@@ -67,12 +69,14 @@ class PINN(PINNInterface, SingleSolverInterface): :param torch.nn.Module loss: The loss function to be minimized; default `None`. """ - super().__init__(model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - loss=loss) + super().__init__( + model=model, + problem=problem, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + loss=loss, + ) def loss_phys(self, samples, equation): """ @@ -112,7 +116,4 @@ class PINN(PINNInterface, SingleSolverInterface): } ) self.scheduler.hook(self.optimizer) - return ( - [self.optimizer.instance], - [self.scheduler.instance] - ) + return ([self.optimizer.instance], [self.scheduler.instance]) diff --git a/pina/solver/physic_informed_solver/pinn_interface.py b/pina/solver/physic_informed_solver/pinn_interface.py index c79e1ba..20ce4b2 100644 --- a/pina/solver/physic_informed_solver/pinn_interface.py +++ b/pina/solver/physic_informed_solver/pinn_interface.py @@ -1,4 +1,4 @@ -""" Module for Physics Informed Neural Network Interface.""" +"""Module for Physics Informed Neural Network Interface.""" from abc import ABCMeta, abstractmethod import torch @@ -11,7 +11,7 @@ from ...problem import InverseProblem from ...condition import ( InputOutputPointsCondition, InputPointsEquationCondition, - DomainEquationCondition + DomainEquationCondition, ) @@ -20,22 +20,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): Base PINN solver class. This class implements the Solver Interface for Physics Informed Neural Network solver. - This class can be used to define PINNs with multiple ``optimizers``, + This class can be used to define PINNs with multiple ``optimizers``, and/or ``models``. By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`, so the user can choose what type of problem the implemented solver, inheriting from this class, is designed to solve. """ + accepted_conditions_types = ( InputOutputPointsCondition, InputPointsEquationCondition, - DomainEquationCondition + DomainEquationCondition, ) - def __init__(self, - problem, - loss=None, - **kwargs): + def __init__(self, problem, loss=None, **kwargs): """ :param AbstractProblem problem: A problem definition instance. 
:param torch.nn.Module loss: The loss function to be minimized, @@ -45,9 +43,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): if loss is None: loss = torch.nn.MSELoss() - super().__init__(problem=problem, - use_lt=True, - **kwargs) + super().__init__(problem=problem, use_lt=True, **kwargs) # check consistency check_consistency(loss, (LossInterface, _Loss), subclass=False) @@ -72,14 +68,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): def validation_step(self, batch): losses = self._run_optimization_cycle(batch, self._residual_loss) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) - self.store_log('val_loss', loss, self.get_batch_size(batch)) + self.store_log("val_loss", loss, self.get_batch_size(batch)) return loss @torch.set_grad_enabled(True) def test_step(self, batch): losses = self._run_optimization_cycle(batch, self._residual_loss) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) - self.store_log('test_loss', loss, self.get_batch_size(batch)) + self.store_log("test_loss", loss, self.get_batch_size(batch)) return loss def loss_data(self, input_pts, output_pts): @@ -129,42 +125,38 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): except TypeError: # this occurs when the function has three inputs (inverse problem) residual = equation.residual( - samples, - self.forward(samples), - self._params + samples, self.forward(samples), self._params ) return residual def _residual_loss(self, samples, equation): residuals = self.compute_residual(samples, equation) return self.loss(residuals, torch.zeros_like(residuals)) - + def _run_optimization_cycle(self, batch, loss_residuals): condition_loss = {} for condition_name, points in batch: self.__metric = condition_name # if equations are passed - if 'output_points' not in points: - input_pts = points['input_points'] + if "output_points" not in points: + input_pts = points["input_points"] condition = self.problem.conditions[condition_name] loss = loss_residuals( - input_pts.requires_grad_(), - condition.equation + input_pts.requires_grad_(), condition.equation ) # if data are passed else: - input_pts = points['input_points'] - output_pts = points['output_points'] + input_pts = points["input_points"] + output_pts = points["output_points"] loss = self.loss_data( - input_pts=input_pts.requires_grad_(), - output_pts=output_pts + input_pts=input_pts.requires_grad_(), output_pts=output_pts ) # append loss condition_loss[condition_name] = loss # clamp unknown parameters in InverseProblem (if needed) self._clamp_params() return condition_loss - + def _clamp_inverse_problem_params(self): """ Clamps the parameters of the inverse problem @@ -175,14 +167,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): self.problem.unknown_parameter_domain.range_[v][0], self.problem.unknown_parameter_domain.range_[v][1], ) - + @property def loss(self): """ Loss used for training. """ return self._loss - + @property def current_condition_name(self): """ diff --git a/pina/solver/physic_informed_solver/rba_pinn.py b/pina/solver/physic_informed_solver/rba_pinn.py index 38e5061..3f189e9 100644 --- a/pina/solver/physic_informed_solver/rba_pinn.py +++ b/pina/solver/physic_informed_solver/rba_pinn.py @@ -1,4 +1,4 @@ -""" Module for Residual-Based Attention PINN. """ +"""Module for Residual-Based Attention PINN.""" from copy import deepcopy import torch @@ -66,15 +66,17 @@ class RBAPINN(PINN): j.cma.2024.116805 `_. 
""" - def __init__(self, - problem, - model, - optimizer=None, - scheduler=None, - weighting=None, - loss=None, - eta=0.001, - gamma=0.999): + def __init__( + self, + problem, + model, + optimizer=None, + scheduler=None, + weighting=None, + loss=None, + eta=0.001, + gamma=0.999, + ): """ :param torch.nn.Module model: The neural network model to use. :param AbstractProblem problem: The formulation of the problem. @@ -86,17 +88,19 @@ class RBAPINN(PINN): default `None`. :param torch.nn.Module loss: The loss function to be minimized; default `None`. - :param float | int eta: The learning rate for the weights of the + :param float | int eta: The learning rate for the weights of the residual; default 0.001. :param float gamma: The decay parameter in the update of the weights of the residual. Must be between 0 and 1; default 0.999. """ - super().__init__(model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - loss=loss) + super().__init__( + model=model, + problem=problem, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + loss=loss, + ) # check consistency check_consistency(eta, (float, int)) @@ -119,9 +123,11 @@ class RBAPINN(PINN): # for now RBAPINN is implemented only for batch_size = None def on_train_start(self): if self.trainer.batch_size is not None: - raise NotImplementedError("RBAPINN only works with full batch " - "size, set batch_size=None inside the " - "Trainer to use the solver.") + raise NotImplementedError( + "RBAPINN only works with full batch " + "size, set batch_size=None inside the " + "Trainer to use the solver." + ) return super().on_train_start() def _vect_to_scalar(self, loss_value): @@ -160,10 +166,11 @@ class RBAPINN(PINN): cond = self.current_condition_name r_norm = ( - self.eta * torch.abs(residual) + self.eta + * torch.abs(residual) / (torch.max(torch.abs(residual)) + 1e-12) ) - self.weights[cond] = (self.gamma*self.weights[cond] + r_norm).detach() + self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach() loss_value = self._vectorial_loss( torch.zeros_like(residual, requires_grad=True), residual diff --git a/pina/solver/physic_informed_solver/self_adaptive_pinn.py b/pina/solver/physic_informed_solver/self_adaptive_pinn.py index 9314ae9..185643d 100644 --- a/pina/solver/physic_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physic_informed_solver/self_adaptive_pinn.py @@ -1,4 +1,4 @@ -""" Module for Self-Adaptive PINN. """ +"""Module for Self-Adaptive PINN.""" import torch from copy import deepcopy @@ -99,25 +99,27 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): j.jcp.2022.111722 `_. """ - def __init__(self, - problem, - model, - weight_function=torch.nn.Sigmoid(), - optimizer_model=None, - optimizer_weights=None, - scheduler_model=None, - scheduler_weights=None, - weighting=None, - loss=None): + def __init__( + self, + problem, + model, + weight_function=torch.nn.Sigmoid(), + optimizer_model=None, + optimizer_weights=None, + scheduler_model=None, + scheduler_weights=None, + weighting=None, + loss=None, + ): """ :param AbstractProblem problem: The formulation of the problem. - :param torch.nn.Module model: The neural network model to use for + :param torch.nn.Module model: The neural network model to use for the model. 
:param torch.nn.Module weight_function: The neural network model related to the Self-Adaptive PINN mask; default `torch.nn.Sigmoid()` - :param torch.optim.Optimizer optimizer_model: The neural network + :param torch.optim.Optimizer optimizer_model: The neural network optimizer to use for the model network; default `None`. - :param torch.optim.Optimizer optimizer_weights: The neural network + :param torch.optim.Optimizer optimizer_weights: The neural network optimizer to use for mask model; default `None`. :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler for the model; default `None`. @@ -137,12 +139,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): weights_dict[condition_name] = Weights(weight_function) weights_dict = torch.nn.ModuleDict(weights_dict) - super().__init__(models=[model, weights_dict], - problem=problem, - optimizers=[optimizer_model, optimizer_weights], - schedulers=[scheduler_model, scheduler_weights], - weighting=weighting, - loss=loss) + super().__init__( + models=[model, weights_dict], + problem=problem, + optimizers=[optimizer_model, optimizer_weights], + schedulers=[scheduler_model, scheduler_weights], + weighting=weighting, + loss=loss, + ) # Set automatic optimization to False self.automatic_optimization = False @@ -202,7 +206,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): # create a new one by setting requires_grad to True. # In alternative set `retain_graph=True`. samples = samples.detach() - samples.requires_grad_()# = True + samples.requires_grad_() # = True # Train the model weighted_loss = self._loss_phys(samples, equation) @@ -244,20 +248,18 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): self.optimizer_weights.hook(self.weights_dict.parameters()) if isinstance(self.problem, InverseProblem): self.optimizer_model.instance.add_param_group( - { - "params": [ - self._params[var] - for var in self.problem.unknown_variables - ] - } - ) + { + "params": [ + self._params[var] + for var in self.problem.unknown_variables + ] + } + ) self.scheduler_model.hook(self.optimizer_model) self.scheduler_weights.hook(self.optimizer_weights) return ( - [self.optimizer_model.instance, - self.optimizer_weights.instance], - [self.scheduler_model.instance, - self.scheduler_weights.instance] + [self.optimizer_model.instance, self.optimizer_weights.instance], + [self.scheduler_model.instance, self.scheduler_weights.instance], ) def on_train_batch_end(self, outputs, batch, batch_idx): @@ -275,8 +277,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): """ # increase by one the counter of optimization to save loggers ( - self.trainer.fit_loop.epoch_loop.manual_optimization - .optim_step_progress.total.completed + self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed ) += 1 return super().on_train_batch_end(outputs, batch, batch_idx) @@ -291,19 +292,22 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): :rtype: Any """ if self.trainer.batch_size is not None: - raise NotImplementedError("SelfAdaptivePINN only works with full " - "batch size, set batch_size=None inside " - "the Trainer to use the solver.") + raise NotImplementedError( + "SelfAdaptivePINN only works with full " + "batch size, set batch_size=None inside " + "the Trainer to use the solver." 
+ ) device = torch.device( self.trainer._accelerator_connector._accelerator_flag ) # Initialize the self adaptive weights only for training points - for condition_name, tensor in ( - self.trainer.data_module.train_dataset.input_points.items() - ): - self.weights_dict[condition_name].sa_weights.data = ( - torch.rand((tensor.shape[0], 1), device=device) + for ( + condition_name, + tensor, + ) in self.trainer.data_module.train_dataset.input_points.items(): + self.weights_dict[condition_name].sa_weights.data = torch.rand( + (tensor.shape[0], 1), device=device ) return super().on_train_start() @@ -318,11 +322,11 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): # First initialize self-adaptive weights with correct shape, # then load the values from the checkpoint. for condition_name, _ in self.problem.input_pts.items(): - shape = checkpoint['state_dict'][ + shape = checkpoint["state_dict"][ f"_pina_models.1.{condition_name}.sa_weights" ].shape - self.weights_dict[condition_name].sa_weights.data = ( - torch.rand(shape) + self.weights_dict[condition_name].sa_weights.data = torch.rand( + shape ) return super().on_load_checkpoint(checkpoint) diff --git a/pina/solver/reduced_order_model.py b/pina/solver/reduced_order_model.py index 16580e1..54aa8a2 100644 --- a/pina/solver/reduced_order_model.py +++ b/pina/solver/reduced_order_model.py @@ -1,4 +1,4 @@ -""" Module for ReducedOrderModelSolver """ +"""Module for ReducedOrderModelSolver""" import torch @@ -126,7 +126,7 @@ class ReducedOrderModelSolver(SupervisedSolver): optimizer=optimizer, scheduler=scheduler, weighting=weighting, - use_lt=use_lt + use_lt=use_lt, ) # assert reduction object contains encode/ decode @@ -185,4 +185,4 @@ class ReducedOrderModelSolver(SupervisedSolver): reduction_network.decode(encode_repr_reduction_network), output_pts ) - return loss_encode + loss_reconstruction \ No newline at end of file + return loss_encode + loss_reconstruction diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 181390d..2ca7c1c 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -1,4 +1,4 @@ -""" Solver module. """ +"""Solver module.""" import lightning import torch @@ -18,10 +18,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): SolverInterface base class. This class is a wrapper of LightningModule. """ - def __init__(self, - problem, - weighting, - use_lt): + def __init__(self, problem, weighting, use_lt): """ :param problem: A problem definition instance. 
:type problem: AbstractProblem @@ -82,7 +79,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ losses = self.optimization_cycle(batch) for name, value in losses.items(): - self.store_log(f'{name}_loss', value.item(), self.get_batch_size(batch)) + self.store_log( + f"{name}_loss", value.item(), self.get_batch_size(batch) + ) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) return loss @@ -96,7 +95,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): :rtype: LabelTensor """ loss = self._optimization_cycle(batch=batch) - self.store_log('train_loss', loss, self.get_batch_size(batch)) + self.store_log("train_loss", loss, self.get_batch_size(batch)) return loss def validation_step(self, batch): @@ -107,7 +106,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): :type batch: tuple """ loss = self._optimization_cycle(batch=batch) - self.store_log('val_loss', loss, self.get_batch_size(batch)) + self.store_log("val_loss", loss, self.get_batch_size(batch)) def test_step(self, batch): """ @@ -117,14 +116,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): :type batch: tuple """ loss = self._optimization_cycle(batch=batch) - self.store_log('test_loss', loss, self.get_batch_size(batch)) + self.store_log("test_loss", loss, self.get_batch_size(batch)) def store_log(self, name, value, batch_size): - self.log(name=name, - value=value, - batch_size=batch_size, - **self.trainer.logging_kwargs - ) + self.log( + name=name, + value=value, + batch_size=batch_size, + **self.trainer.logging_kwargs, + ) @abstractmethod def forward(self, *args, **kwargs): @@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): # assuming batch is a custom Batch object batch_size = 0 for data in batch: - batch_size += len(data[1]['input_points']) + batch_size += len(data[1]["input_points"]) return batch_size @staticmethod @@ -203,8 +203,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): def _check_already_compiled(self): models = self._pina_models - if len(models) == 1 and isinstance(self._pina_models[0], - torch.nn.ModuleDict): + if len(models) == 1 and isinstance( + self._pina_models[0], torch.nn.ModuleDict + ): models = list(self._pina_models.values()) for model in models: if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)): @@ -225,13 +226,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): class SingleSolverInterface(SolverInterface): - def __init__(self, - problem, - model, - optimizer=None, - scheduler=None, - weighting=None, - use_lt=True): + def __init__( + self, + problem, + model, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=True, + ): """ :param problem: A problem definition instance. 
:type problem: AbstractProblem @@ -248,9 +251,7 @@ class SingleSolverInterface(SolverInterface): if scheduler is None: scheduler = self.default_torch_scheduler() - super().__init__(problem=problem, - use_lt=use_lt, - weighting=weighting) + super().__init__(problem=problem, use_lt=use_lt, weighting=weighting) # check consistency of models argument and encapsulate in list check_consistency(model, torch.nn.Module) @@ -284,10 +285,7 @@ class SingleSolverInterface(SolverInterface): """ self.optimizer.hook(self.model.parameters()) self.scheduler.hook(self.optimizer) - return ( - [self.optimizer.instance], - [self.scheduler.instance] - ) + return ([self.optimizer.instance], [self.scheduler.instance]) def _compile_model(self): if isinstance(self._pina_models[0], torch.nn.ModuleDict): @@ -330,13 +328,15 @@ class MultiSolverInterface(SolverInterface): SolverInterface class """ - def __init__(self, - problem, - models, - optimizers=None, - schedulers=None, - weighting=None, - use_lt=True): + def __init__( + self, + problem, + models, + optimizers=None, + schedulers=None, + weighting=None, + use_lt=True, + ): """ :param problem: A problem definition instance. :type problem: AbstractProblem @@ -351,9 +351,9 @@ class MultiSolverInterface(SolverInterface): """ if not isinstance(models, (list, tuple)) or len(models) < 2: raise ValueError( - 'models should be list[torch.nn.Module] or ' - 'tuple[torch.nn.Module] with len greater than ' - 'one.' + "models should be list[torch.nn.Module] or " + "tuple[torch.nn.Module] with len greater than " + "one." ) if any(opt is None for opt in optimizers): @@ -368,9 +368,7 @@ class MultiSolverInterface(SolverInterface): for sched in schedulers ] - super().__init__(problem=problem, - use_lt=use_lt, - weighting=weighting) + super().__init__(problem=problem, use_lt=use_lt, weighting=weighting) # check consistency of models argument and encapsulate in list check_consistency(models, torch.nn.Module) @@ -400,15 +398,15 @@ class MultiSolverInterface(SolverInterface): :return: The optimizers and the schedulers :rtype: tuple(list, list) """ - for optimizer, scheduler, model in zip(self.optimizers, - self.schedulers, - self.models): + for optimizer, scheduler, model in zip( + self.optimizers, self.schedulers, self.models + ): optimizer.hook(model.parameters()) scheduler.hook(optimizer) return ( [optimizer.instance for optimizer in self.optimizers], - [scheduler.instance for scheduler in self.schedulers] + [scheduler.instance for scheduler in self.schedulers], ) def _compile_model(self): diff --git a/pina/solver/supervised.py b/pina/solver/supervised.py index b453b65..56771b8 100644 --- a/pina/solver/supervised.py +++ b/pina/solver/supervised.py @@ -1,4 +1,5 @@ -""" Module for SupervisedSolver """ +"""Module for SupervisedSolver""" + import torch from torch.nn.modules.loss import _Loss from .solver import SingleSolverInterface @@ -38,14 +39,16 @@ class SupervisedSolver(SingleSolverInterface): accepted_conditions_types = InputOutputPointsCondition - def __init__(self, - problem, - model, - loss=None, - optimizer=None, - scheduler=None, - weighting=None, - use_lt=True): + def __init__( + self, + problem, + model, + loss=None, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=True, + ): """ :param AbstractProblem problem: The formualation of the problem. :param torch.nn.Module model: The neural network model to use. 
@@ -61,16 +64,19 @@ class SupervisedSolver(SingleSolverInterface): if loss is None: loss = torch.nn.MSELoss() - super().__init__(model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - use_lt=use_lt) + super().__init__( + model=model, + problem=problem, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) # check consistency - check_consistency(loss, (LossInterface, _Loss, torch.nn.Module), - subclass=False) + check_consistency( + loss, (LossInterface, _Loss, torch.nn.Module), subclass=False + ) self._loss = loss def optimization_cycle(self, batch): @@ -79,7 +85,7 @@ class SupervisedSolver(SingleSolverInterface): in the given batch. :param batch: A batch of data, where each element is a tuple containing - a condition name and a dictionary of points. + a condition name and a dictionary of points. :type batch: list of tuples (str, dict) :return: The computed loss for the all conditions in the batch, cast to a subclass of `torch.Tensor`. It should return a dict @@ -88,9 +94,13 @@ class SupervisedSolver(SingleSolverInterface): """ condition_loss = {} for condition_name, points in batch: - input_pts, output_pts = points['input_points'], points['output_points'] + input_pts, output_pts = ( + points["input_points"], + points["output_points"], + ) condition_loss[condition_name] = self.loss_data( - input_pts=input_pts, output_pts=output_pts) + input_pts=input_pts, output_pts=output_pts + ) return condition_loss def loss_data(self, input_pts, output_pts): @@ -114,4 +124,4 @@ class SupervisedSolver(SingleSolverInterface): """ Loss for training. """ - return self._loss \ No newline at end of file + return self._loss diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index cc4ca13..b7373a3 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -8,6 +8,7 @@ from ..utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - f"'pina.solvers' is deprecated and will be removed " - f"in future versions. Please use 'pina.solver' instead.", - DeprecationWarning) \ No newline at end of file + f"'pina.solvers' is deprecated and will be removed " + f"in future versions. Please use 'pina.solver' instead.", + DeprecationWarning, +) diff --git a/pina/solvers/pinns/__init__.py b/pina/solvers/pinns/__init__.py index 90fee4b..78184b0 100644 --- a/pina/solvers/pinns/__init__.py +++ b/pina/solvers/pinns/__init__.py @@ -8,7 +8,8 @@ from ...utils import custom_warning_format warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) warnings.warn( - "'pina.solvers.pinns' is deprecated and will be removed " - "in future versions. Please use " - "'pina.solver.physic_informed_solver' instead.", - DeprecationWarning) \ No newline at end of file + "'pina.solvers.pinns' is deprecated and will be removed " + "in future versions. Please use " + "'pina.solver.physic_informed_solver' instead.", + DeprecationWarning, +) diff --git a/pina/trainer.py b/pina/trainer.py index 3fe0132..8831d40 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,4 +1,5 @@ -""" Trainer module. 
""" +"""Trainer module.""" + import sys import torch import lightning @@ -9,18 +10,20 @@ from .solver import SolverInterface, PINNInterface class Trainer(lightning.pytorch.Trainer): - def __init__(self, - solver, - batch_size=None, - train_size=.7, - test_size=.2, - val_size=.1, - predict_size=0., - compile=None, - automatic_batching=None, - num_workers=None, - pin_memory=None, - **kwargs): + def __init__( + self, + solver, + batch_size=None, + train_size=0.7, + test_size=0.2, + val_size=0.1, + predict_size=0.0, + compile=None, + automatic_batching=None, + num_workers=None, + pin_memory=None, + **kwargs, + ): """ PINA Trainer class for costumizing every aspect of training via flags. @@ -75,30 +78,34 @@ class Trainer(lightning.pytorch.Trainer): else: num_workers = 0 if train_size + test_size + val_size + predict_size > 1: - raise ValueError('train_size, test_size, val_size and predict_size ' - 'must sum up to 1.') + raise ValueError( + "train_size, test_size, val_size and predict_size " + "must sum up to 1." + ) for size in [train_size, test_size, val_size, predict_size]: if size < 0 or size > 1: - raise ValueError('splitting sizes for train, validation, test ' - 'and prediction must be between [0, 1].') + raise ValueError( + "splitting sizes for train, validation, test " + "and prediction must be between [0, 1]." + ) if batch_size is not None: check_consistency(batch_size, int) # inference mode set to false when validating/testing PINNs otherwise # gradient is not tracked and optimization_cycle fails if isinstance(solver, PINNInterface): - kwargs['inference_mode'] = False + kwargs["inference_mode"] = False # Logging depends on the batch size, when batch_size is None then # log_every_n_steps should be zero if batch_size is None: - kwargs['log_every_n_steps'] = 0 + kwargs["log_every_n_steps"] = 0 else: - kwargs.setdefault('log_every_n_steps', 50) # default for lightning + kwargs.setdefault("log_every_n_steps", 50) # default for lightning # Setting default kwargs, overriding lightning defaults - kwargs.setdefault('enable_progress_bar', True) - kwargs.setdefault('logger', None) + kwargs.setdefault("enable_progress_bar", True) + kwargs.setdefault("logger", None) super().__init__(**kwargs) @@ -106,27 +113,37 @@ class Trainer(lightning.pytorch.Trainer): if compile is None or sys.platform == "win32": compile = False - self.automatic_batching = automatic_batching if automatic_batching \ - is not None else False + self.automatic_batching = ( + automatic_batching if automatic_batching is not None else False + ) # set attributes self.compile = compile self.solver = solver self.batch_size = batch_size self._move_to_device() self.data_module = None - self._create_datamodule(train_size, test_size, val_size, predict_size, - batch_size, automatic_batching, pin_memory, - num_workers) + self._create_datamodule( + train_size, + test_size, + val_size, + predict_size, + batch_size, + automatic_batching, + pin_memory, + num_workers, + ) # logging self.logging_kwargs = { - 'logger': bool( - kwargs['logger'] is None or kwargs['logger'] is True), - 'sync_dist': bool( - len(self._accelerator_connector._parallel_devices) > 1), - 'on_step': bool(kwargs['log_every_n_steps'] > 0), - 'prog_bar': bool(kwargs['enable_progress_bar']), - 'on_epoch': True + "logger": bool( + kwargs["logger"] is None or kwargs["logger"] is True + ), + "sync_dist": bool( + len(self._accelerator_connector._parallel_devices) > 1 + ), + "on_step": bool(kwargs["log_every_n_steps"] > 0), + "prog_bar": bool(kwargs["enable_progress_bar"]), + 
"on_epoch": True, } def _move_to_device(self): @@ -136,32 +153,39 @@ class Trainer(lightning.pytorch.Trainer): if hasattr(pb, "unknown_parameters"): for key in pb.unknown_parameters: pb.unknown_parameters[key] = torch.nn.Parameter( - pb.unknown_parameters[key].data.to(device)) + pb.unknown_parameters[key].data.to(device) + ) - def _create_datamodule(self, - train_size, - test_size, - val_size, - predict_size, - batch_size, - automatic_batching, - pin_memory, - num_workers): + def _create_datamodule( + self, + train_size, + test_size, + val_size, + predict_size, + batch_size, + automatic_batching, + pin_memory, + num_workers, + ): """ This method is used here because is resampling is needed during training, there is no need to define to touch the trainer dataloader, just call the method. """ if not self.solver.problem.are_all_domains_discretised: - error_message = '\n'.join([ - f"""{" " * 13} ---> Domain {key} { + error_message = "\n".join( + [ + f"""{" " * 13} ---> Domain {key} { "sampled" if key in self.solver.problem.discretised_domains else - "not sampled"}""" for key in - self.solver.problem.domains.keys() - ]) - raise RuntimeError('Cannot create Trainer if not all conditions ' - 'are sampled. The Trainer got the following:\n' - f'{error_message}') + "not sampled"}""" + for key in self.solver.problem.domains.keys() + ] + ) + raise RuntimeError( + "Cannot create Trainer if not all conditions " + "are sampled. The Trainer got the following:\n" + f"{error_message}" + ) self.data_module = PinaDataModule( self.solver.problem, train_size=train_size, @@ -171,7 +195,8 @@ class Trainer(lightning.pytorch.Trainer): batch_size=batch_size, automatic_batching=automatic_batching, num_workers=num_workers, - pin_memory=pin_memory) + pin_memory=pin_memory, + ) def train(self, **kwargs): """ diff --git a/pina/utils.py b/pina/utils.py index e56abcf..c4cb876 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -8,10 +8,11 @@ from .label_tensor import LabelTensor def custom_warning_format( - message, category, filename, lineno, file=None, line=None - ): + message, category, filename, lineno, file=None, line=None +): return f"{filename}: {category.__name__}: {message}\n" + def check_consistency(object, object_instance, subclass=False): """Helper function to check object inheritance consistency. Given a specific ``'object'`` we check if the object is @@ -39,6 +40,7 @@ def check_consistency(object, object_instance, subclass=False): except AssertionError: raise ValueError(f"{type(obj).__name__} must be {object_instance}.") + def labelize_forward(forward, input_variables, output_variables): """ Wrapper decorator to allow users to enable or disable the use of @@ -51,6 +53,7 @@ def labelize_forward(forward, input_variables, output_variables): :param output_variables: The problem output variables. 
:type output_variables: list[str] | tuple[str] """ + def wrapper(x): x = x.extract(input_variables) output = forward(x) @@ -59,8 +62,10 @@ def labelize_forward(forward, input_variables, output_variables): output = output.as_subclass(LabelTensor) output.labels = output_variables return output + return wrapper + def merge_tensors(tensors): # name to be changed if tensors: return reduce(merge_two_tensors, tensors[1:], tensors[0]) @@ -72,8 +77,9 @@ def merge_two_tensors(tensor1, tensor2): n2 = tensor2.shape[0] tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) - tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), - labels=tensor2.labels) + tensor2 = LabelTensor( + tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels + ) return tensor1.append(tensor2) diff --git a/pyproject.toml b/pyproject.toml index 782691d..caf1099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,9 @@ test = [ "pytest-cov", "scipy" ] +dev = [ + "black @ git+https://github.com/psf/black" +] [project.urls] Homepage = "https://mathlab.github.io/PINA/"
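
Note (not part of the patch): with the "dev" extra above installed (e.g. pip install ".[dev]"), the same style enforced throughout this diff can be reproduced by running black over the package and test directories, presumably something like "black pina tests" given the commit message. As a minimal, illustrative sketch only, black can also be driven as a library to check whether a snippet already conforms; black.format_str and black.Mode are black's documented library entry points, while the sample string below is purely hypothetical:

    # Illustrative only -- not part of this patch.
    # Checks whether a code string already follows black's default style.
    import black

    SRC = 'x = {"a":1,  "b":2}\n'  # hypothetical, deliberately unformatted

    formatted = black.format_str(SRC, mode=black.Mode())  # default 88-char style
    print(formatted)          # x = {"a": 1, "b": 2}
    print(formatted == SRC)   # False -> black would reformat SRC

Pinning the dev dependency to black's git main (as in the pyproject.toml hunk above) tracks the latest style; projects that prefer reproducible formatting typically pin a released version instead.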