From 5445559cb236c97419a019e92e6f38b3aaf7f791 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 18:30:46 +0200 Subject: [PATCH] Format Python code with psf/black push (#325) * :art: Format Python code with psf/black --- pina/callbacks/__init__.py | 4 +- pina/callbacks/processing_callbacks.py | 24 ++++-- pina/model/layers/__init__.py | 2 +- pina/model/layers/rbf_layer.py | 104 +++++++++++++++---------- pina/problem/abstract_problem.py | 7 +- 5 files changed, 85 insertions(+), 56 deletions(-) diff --git a/pina/callbacks/__init__.py b/pina/callbacks/__init__.py index 4ba0271..e1eaf82 100644 --- a/pina/callbacks/__init__.py +++ b/pina/callbacks/__init__.py @@ -2,8 +2,8 @@ __all__ = [ "SwitchOptimizer", "R3Refinement", "MetricTracker", - "PINAProgressBar" - ] + "PINAProgressBar", +] from .optimizer_callbacks import SwitchOptimizer from .adaptive_refinment_callbacks import R3Refinement diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py index c6175b0..a70218e 100644 --- a/pina/callbacks/processing_callbacks.py +++ b/pina/callbacks/processing_callbacks.py @@ -6,9 +6,12 @@ import torch import copy from pytorch_lightning.callbacks import Callback, TQDMProgressBar -from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics +from lightning.pytorch.callbacks.progress.progress_bar import ( + get_standard_metrics, +) from pina.utils import check_consistency + class MetricTracker(Callback): def __init__(self): @@ -37,7 +40,7 @@ class MetricTracker(Callback): def on_train_epoch_end(self, trainer, pl_module): """ - Collect and track metrics at the end of each training epoch. + Collect and track metrics at the end of each training epoch. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer @@ -68,7 +71,8 @@ class MetricTracker(Callback): class PINAProgressBar(TQDMProgressBar): BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" - def __init__(self, metrics='mean', **kwargs): + + def __init__(self, metrics="mean", **kwargs): """ PINA Implementation of a Lightning Callback for enriching the progress bar. @@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar): if pbar_metrics: pbar_metrics = { key: pbar_metrics[key] for key in self._sorted_metrics - } + } duplicates = list(standard_metrics.keys() & pbar_metrics.keys()) if duplicates: rank_zero_warn( @@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar): ) return {**standard_metrics, **pbar_metrics} - + def on_fit_start(self, trainer, pl_module): """ Check that the metrics defined in the initialization are available, @@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar): """ # Check if all keys in sort_keys are present in the dictionary for key in self._sorted_metrics: - if key not in trainer.solver.problem.conditions.keys() and key != 'mean': + if ( + key not in trainer.solver.problem.conditions.keys() + and key != "mean" + ): raise KeyError(f"Key '{key}' is not present in the dictionary") # add the loss pedix self._sorted_metrics = [ - metric + '_loss' for metric in self._sorted_metrics] - return super().on_fit_start(trainer, pl_module) \ No newline at end of file + metric + "_loss" for metric in self._sorted_metrics + ] + return super().on_fit_start(trainer, pl_module) diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index eb12bf0..898ca43 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -13,7 +13,7 @@ __all__ = [ "FourierFeatureEmbedding", "AVNOBlock", "LowRankBlock", - "RBFBlock" + "RBFBlock", ] from .convolution_2d import ContinuousConvBlock diff --git a/pina/model/layers/rbf_layer.py b/pina/model/layers/rbf_layer.py index ef55406..e088d00 100644 --- a/pina/model/layers/rbf_layer.py +++ b/pina/model/layers/rbf_layer.py @@ -6,65 +6,74 @@ from itertools import combinations_with_replacement import torch from ...utils import check_consistency + def linear(r): - ''' + """ Linear radial basis function. - ''' + """ return -r + def thin_plate_spline(r, eps=1e-7): - ''' + """ Thin plate spline radial basis function. - ''' + """ r = torch.clamp(r, min=eps) return r**2 * torch.log(r) + def cubic(r): - ''' + """ Cubic radial basis function. - ''' + """ return r**3 + def quintic(r): - ''' + """ Quintic radial basis function. - ''' - return -r**5 + """ + return -(r**5) + def multiquadric(r): - ''' + """ Multiquadric radial basis function. - ''' + """ return -torch.sqrt(r**2 + 1) + def inverse_multiquadric(r): - ''' + """ Inverse multiquadric radial basis function. - ''' - return 1/torch.sqrt(r**2 + 1) + """ + return 1 / torch.sqrt(r**2 + 1) + def inverse_quadratic(r): - ''' + """ Inverse quadratic radial basis function. - ''' - return 1/(r**2 + 1) + """ + return 1 / (r**2 + 1) + def gaussian(r): - ''' + """ Gaussian radial basis function. - ''' - return torch.exp(-r**2) + """ + return torch.exp(-(r**2)) + radial_functions = { - "linear": linear, - "thin_plate_spline": thin_plate_spline, - "cubic": cubic, - "quintic": quintic, - "multiquadric": multiquadric, - "inverse_multiquadric": inverse_multiquadric, - "inverse_quadratic": inverse_quadratic, - "gaussian": gaussian - } + "linear": linear, + "thin_plate_spline": thin_plate_spline, + "cubic": cubic, + "quintic": quintic, + "multiquadric": multiquadric, + "inverse_multiquadric": inverse_multiquadric, + "inverse_quadratic": inverse_quadratic, + "gaussian": gaussian, +} scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"} @@ -73,8 +82,8 @@ min_degree_funcs = { "linear": 0, "thin_plate_spline": 1, "cubic": 1, - "quintic": 2 - } + "quintic": 2, +} class RBFBlock(torch.nn.Module): @@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module): it is inspired from the implementation in `torchrbf. `_ """ + def __init__( self, neighbors=None, @@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module): if value < min_degree: warnings.warn( "`degree` is too small for this kernel. Setting to " - f"{min_degree}.", UserWarning, + f"{min_degree}.", + UserWarning, ) self._degree = value @@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module): ) if isinstance(self.smoothing, (int, float)): - self.smoothing = torch.full((y.shape[0],), self.smoothing - ).float().to(y.device) + self.smoothing = ( + torch.full((y.shape[0],), self.smoothing).float().to(y.device) + ) def fit(self, y, d): """ @@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module): raise NotImplementedError("neighbors currently not supported") powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to( - y.device) + y.device + ) if powers.shape[0] > nobs: - raise ValueError("The data is not compatible with the " - "requested degree.") + raise ValueError( + "The data is not compatible with the requested degree." + ) if self.neighbors is None: - self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y, - self.d.reshape((self.y.shape[0], -1)), - self.smoothing, self.kernel, self.epsilon, powers) + self._shift, self._scale, self._coeffs = RBFBlock.solve( + self.y, + self.d.reshape((self.y.shape[0], -1)), + self.smoothing, + self.kernel, + self.epsilon, + powers, + ) self.powers = powers @@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module): of the interpolator """ - lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel, - epsilon, powers) + lhs, rhs, shift, scale = RBFBlock.build( + y, d, smoothing, kernel, epsilon, powers + ) try: coeffs = torch.linalg.solve(lhs, rhs) except RuntimeError as e: diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 8b98ec9..6e5e317 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -196,9 +196,10 @@ class AbstractProblem(metaclass=ABCMeta): # check consistency location locations_to_sample = [ - condition for condition in self.conditions - if hasattr(self.conditions[condition], 'location') - ] + condition + for condition in self.conditions + if hasattr(self.conditions[condition], "location") + ] if locations == "all": # only locations that can be sampled locations = locations_to_sample