Format Python code with psf/black push (#325)

* 🎨 Format Python code with psf/black
This commit is contained in:
github-actions[bot]
2024-08-12 18:30:46 +02:00
committed by GitHub
parent cce9876751
commit 5445559cb2
5 changed files with 85 additions and 56 deletions

View File

@@ -2,8 +2,8 @@ __all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar"
]
"PINAProgressBar",
]
from .optimizer_callbacks import SwitchOptimizer
from .adaptive_refinment_callbacks import R3Refinement

View File

@@ -6,9 +6,12 @@ import torch
import copy
from pytorch_lightning.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics,
)
from pina.utils import check_consistency
class MetricTracker(Callback):
def __init__(self):
@@ -37,7 +40,7 @@ class MetricTracker(Callback):
def on_train_epoch_end(self, trainer, pl_module):
"""
Collect and track metrics at the end of each training epoch.
Collect and track metrics at the end of each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics='mean', **kwargs):
def __init__(self, metrics="mean", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
@@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar):
if pbar_metrics:
pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics
}
}
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
rank_zero_warn(
@@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar):
)
return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module):
"""
Check that the metrics defined in the initialization are available,
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
"""
# Check if all keys in sort_keys are present in the dictionary
for key in self._sorted_metrics:
if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
if (
key not in trainer.solver.problem.conditions.keys()
and key != "mean"
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix
self._sorted_metrics = [
metric + '_loss' for metric in self._sorted_metrics]
return super().on_fit_start(trainer, pl_module)
metric + "_loss" for metric in self._sorted_metrics
]
return super().on_fit_start(trainer, pl_module)

View File

@@ -13,7 +13,7 @@ __all__ = [
"FourierFeatureEmbedding",
"AVNOBlock",
"LowRankBlock",
"RBFBlock"
"RBFBlock",
]
from .convolution_2d import ContinuousConvBlock

View File

@@ -6,65 +6,74 @@ from itertools import combinations_with_replacement
import torch
from ...utils import check_consistency
def linear(r):
'''
"""
Linear radial basis function.
'''
"""
return -r
def thin_plate_spline(r, eps=1e-7):
'''
"""
Thin plate spline radial basis function.
'''
"""
r = torch.clamp(r, min=eps)
return r**2 * torch.log(r)
def cubic(r):
'''
"""
Cubic radial basis function.
'''
"""
return r**3
def quintic(r):
'''
"""
Quintic radial basis function.
'''
return -r**5
"""
return -(r**5)
def multiquadric(r):
'''
"""
Multiquadric radial basis function.
'''
"""
return -torch.sqrt(r**2 + 1)
def inverse_multiquadric(r):
'''
"""
Inverse multiquadric radial basis function.
'''
return 1/torch.sqrt(r**2 + 1)
"""
return 1 / torch.sqrt(r**2 + 1)
def inverse_quadratic(r):
'''
"""
Inverse quadratic radial basis function.
'''
return 1/(r**2 + 1)
"""
return 1 / (r**2 + 1)
def gaussian(r):
'''
"""
Gaussian radial basis function.
'''
return torch.exp(-r**2)
"""
return torch.exp(-(r**2))
radial_functions = {
"linear": linear,
"thin_plate_spline": thin_plate_spline,
"cubic": cubic,
"quintic": quintic,
"multiquadric": multiquadric,
"inverse_multiquadric": inverse_multiquadric,
"inverse_quadratic": inverse_quadratic,
"gaussian": gaussian
}
"linear": linear,
"thin_plate_spline": thin_plate_spline,
"cubic": cubic,
"quintic": quintic,
"multiquadric": multiquadric,
"inverse_multiquadric": inverse_multiquadric,
"inverse_quadratic": inverse_quadratic,
"gaussian": gaussian,
}
scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}
@@ -73,8 +82,8 @@ min_degree_funcs = {
"linear": 0,
"thin_plate_spline": 1,
"cubic": 1,
"quintic": 2
}
"quintic": 2,
}
class RBFBlock(torch.nn.Module):
@@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module):
it is inspired from the implementation in `torchrbf.
<https://github.com/ArmanMaesumi/torchrbf>`_
"""
def __init__(
self,
neighbors=None,
@@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module):
if value < min_degree:
warnings.warn(
"`degree` is too small for this kernel. Setting to "
f"{min_degree}.", UserWarning,
f"{min_degree}.",
UserWarning,
)
self._degree = value
@@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module):
)
if isinstance(self.smoothing, (int, float)):
self.smoothing = torch.full((y.shape[0],), self.smoothing
).float().to(y.device)
self.smoothing = (
torch.full((y.shape[0],), self.smoothing).float().to(y.device)
)
def fit(self, y, d):
"""
@@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module):
raise NotImplementedError("neighbors currently not supported")
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
y.device)
y.device
)
if powers.shape[0] > nobs:
raise ValueError("The data is not compatible with the "
"requested degree.")
raise ValueError(
"The data is not compatible with the requested degree."
)
if self.neighbors is None:
self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y,
self.d.reshape((self.y.shape[0], -1)),
self.smoothing, self.kernel, self.epsilon, powers)
self._shift, self._scale, self._coeffs = RBFBlock.solve(
self.y,
self.d.reshape((self.y.shape[0], -1)),
self.smoothing,
self.kernel,
self.epsilon,
powers,
)
self.powers = powers
@@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module):
of the interpolator
"""
lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel,
epsilon, powers)
lhs, rhs, shift, scale = RBFBlock.build(
y, d, smoothing, kernel, epsilon, powers
)
try:
coeffs = torch.linalg.solve(lhs, rhs)
except RuntimeError as e:

View File

@@ -196,9 +196,10 @@ class AbstractProblem(metaclass=ABCMeta):
# check consistency location
locations_to_sample = [
condition for condition in self.conditions
if hasattr(self.conditions[condition], 'location')
]
condition
for condition in self.conditions
if hasattr(self.conditions[condition], "location")
]
if locations == "all":
# only locations that can be sampled
locations = locations_to_sample