Format Python code with psf/black push (#325)

* 🎨 Format Python code with psf/black
This commit is contained in:
github-actions[bot]
2024-08-12 18:30:46 +02:00
committed by GitHub
parent cce9876751
commit 5445559cb2
5 changed files with 85 additions and 56 deletions

View File

@@ -2,7 +2,7 @@ __all__ = [
"SwitchOptimizer", "SwitchOptimizer",
"R3Refinement", "R3Refinement",
"MetricTracker", "MetricTracker",
"PINAProgressBar" "PINAProgressBar",
] ]
from .optimizer_callbacks import SwitchOptimizer from .optimizer_callbacks import SwitchOptimizer

View File

@@ -6,9 +6,12 @@ import torch
import copy import copy
from pytorch_lightning.callbacks import Callback, TQDMProgressBar from pytorch_lightning.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics,
)
from pina.utils import check_consistency from pina.utils import check_consistency
class MetricTracker(Callback): class MetricTracker(Callback):
def __init__(self): def __init__(self):
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
class PINAProgressBar(TQDMProgressBar): class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics='mean', **kwargs):
def __init__(self, metrics="mean", **kwargs):
""" """
PINA Implementation of a Lightning Callback for enriching the progress PINA Implementation of a Lightning Callback for enriching the progress
bar. bar.
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
""" """
# Check if all keys in sort_keys are present in the dictionary # Check if all keys in sort_keys are present in the dictionary
for key in self._sorted_metrics: for key in self._sorted_metrics:
if key not in trainer.solver.problem.conditions.keys() and key != 'mean': if (
key not in trainer.solver.problem.conditions.keys()
and key != "mean"
):
raise KeyError(f"Key '{key}' is not present in the dictionary") raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix # add the loss pedix
self._sorted_metrics = [ self._sorted_metrics = [
metric + '_loss' for metric in self._sorted_metrics] metric + "_loss" for metric in self._sorted_metrics
]
return super().on_fit_start(trainer, pl_module) return super().on_fit_start(trainer, pl_module)

View File

@@ -13,7 +13,7 @@ __all__ = [
"FourierFeatureEmbedding", "FourierFeatureEmbedding",
"AVNOBlock", "AVNOBlock",
"LowRankBlock", "LowRankBlock",
"RBFBlock" "RBFBlock",
] ]
from .convolution_2d import ContinuousConvBlock from .convolution_2d import ContinuousConvBlock

View File

@@ -6,54 +6,63 @@ from itertools import combinations_with_replacement
import torch import torch
from ...utils import check_consistency from ...utils import check_consistency
def linear(r): def linear(r):
''' """
Linear radial basis function. Linear radial basis function.
''' """
return -r return -r
def thin_plate_spline(r, eps=1e-7): def thin_plate_spline(r, eps=1e-7):
''' """
Thin plate spline radial basis function. Thin plate spline radial basis function.
''' """
r = torch.clamp(r, min=eps) r = torch.clamp(r, min=eps)
return r**2 * torch.log(r) return r**2 * torch.log(r)
def cubic(r): def cubic(r):
''' """
Cubic radial basis function. Cubic radial basis function.
''' """
return r**3 return r**3
def quintic(r): def quintic(r):
''' """
Quintic radial basis function. Quintic radial basis function.
''' """
return -r**5 return -(r**5)
def multiquadric(r): def multiquadric(r):
''' """
Multiquadric radial basis function. Multiquadric radial basis function.
''' """
return -torch.sqrt(r**2 + 1) return -torch.sqrt(r**2 + 1)
def inverse_multiquadric(r): def inverse_multiquadric(r):
''' """
Inverse multiquadric radial basis function. Inverse multiquadric radial basis function.
''' """
return 1 / torch.sqrt(r**2 + 1) return 1 / torch.sqrt(r**2 + 1)
def inverse_quadratic(r): def inverse_quadratic(r):
''' """
Inverse quadratic radial basis function. Inverse quadratic radial basis function.
''' """
return 1 / (r**2 + 1) return 1 / (r**2 + 1)
def gaussian(r): def gaussian(r):
''' """
Gaussian radial basis function. Gaussian radial basis function.
''' """
return torch.exp(-r**2) return torch.exp(-(r**2))
radial_functions = { radial_functions = {
"linear": linear, "linear": linear,
@@ -63,7 +72,7 @@ radial_functions = {
"multiquadric": multiquadric, "multiquadric": multiquadric,
"inverse_multiquadric": inverse_multiquadric, "inverse_multiquadric": inverse_multiquadric,
"inverse_quadratic": inverse_quadratic, "inverse_quadratic": inverse_quadratic,
"gaussian": gaussian "gaussian": gaussian,
} }
scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"} scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}
@@ -73,7 +82,7 @@ min_degree_funcs = {
"linear": 0, "linear": 0,
"thin_plate_spline": 1, "thin_plate_spline": 1,
"cubic": 1, "cubic": 1,
"quintic": 2 "quintic": 2,
} }
@@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module):
it is inspired from the implementation in `torchrbf. it is inspired from the implementation in `torchrbf.
<https://github.com/ArmanMaesumi/torchrbf>`_ <https://github.com/ArmanMaesumi/torchrbf>`_
""" """
def __init__( def __init__(
self, self,
neighbors=None, neighbors=None,
@@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module):
if value < min_degree: if value < min_degree:
warnings.warn( warnings.warn(
"`degree` is too small for this kernel. Setting to " "`degree` is too small for this kernel. Setting to "
f"{min_degree}.", UserWarning, f"{min_degree}.",
UserWarning,
) )
self._degree = value self._degree = value
@@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module):
) )
if isinstance(self.smoothing, (int, float)): if isinstance(self.smoothing, (int, float)):
self.smoothing = torch.full((y.shape[0],), self.smoothing self.smoothing = (
).float().to(y.device) torch.full((y.shape[0],), self.smoothing).float().to(y.device)
)
def fit(self, y, d): def fit(self, y, d):
""" """
@@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module):
raise NotImplementedError("neighbors currently not supported") raise NotImplementedError("neighbors currently not supported")
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to( powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
y.device) y.device
)
if powers.shape[0] > nobs: if powers.shape[0] > nobs:
raise ValueError("The data is not compatible with the " raise ValueError(
"requested degree.") "The data is not compatible with the requested degree."
)
if self.neighbors is None: if self.neighbors is None:
self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y, self._shift, self._scale, self._coeffs = RBFBlock.solve(
self.y,
self.d.reshape((self.y.shape[0], -1)), self.d.reshape((self.y.shape[0], -1)),
self.smoothing, self.kernel, self.epsilon, powers) self.smoothing,
self.kernel,
self.epsilon,
powers,
)
self.powers = powers self.powers = powers
@@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module):
of the interpolator of the interpolator
""" """
lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel, lhs, rhs, shift, scale = RBFBlock.build(
epsilon, powers) y, d, smoothing, kernel, epsilon, powers
)
try: try:
coeffs = torch.linalg.solve(lhs, rhs) coeffs = torch.linalg.solve(lhs, rhs)
except RuntimeError as e: except RuntimeError as e:

View File

@@ -196,8 +196,9 @@ class AbstractProblem(metaclass=ABCMeta):
# check consistency location # check consistency location
locations_to_sample = [ locations_to_sample = [
condition for condition in self.conditions condition
if hasattr(self.conditions[condition], 'location') for condition in self.conditions
if hasattr(self.conditions[condition], "location")
] ]
if locations == "all": if locations == "all":
# only locations that can be sampled # only locations that can be sampled