Format Python code with psf/black push (#325)
* 🎨 Format Python code with psf/black
This commit is contained in:
committed by
GitHub
parent
cce9876751
commit
5445559cb2
@@ -2,8 +2,8 @@ __all__ = [
|
||||
"SwitchOptimizer",
|
||||
"R3Refinement",
|
||||
"MetricTracker",
|
||||
"PINAProgressBar"
|
||||
]
|
||||
"PINAProgressBar",
|
||||
]
|
||||
|
||||
from .optimizer_callbacks import SwitchOptimizer
|
||||
from .adaptive_refinment_callbacks import R3Refinement
|
||||
|
||||
@@ -6,9 +6,12 @@ import torch
|
||||
import copy
|
||||
|
||||
from pytorch_lightning.callbacks import Callback, TQDMProgressBar
|
||||
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
|
||||
from lightning.pytorch.callbacks.progress.progress_bar import (
|
||||
get_standard_metrics,
|
||||
)
|
||||
from pina.utils import check_consistency
|
||||
|
||||
|
||||
class MetricTracker(Callback):
|
||||
|
||||
def __init__(self):
|
||||
@@ -37,7 +40,7 @@ class MetricTracker(Callback):
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
"""
|
||||
Collect and track metrics at the end of each training epoch.
|
||||
Collect and track metrics at the end of each training epoch.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
|
||||
class PINAProgressBar(TQDMProgressBar):
|
||||
|
||||
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
|
||||
def __init__(self, metrics='mean', **kwargs):
|
||||
|
||||
def __init__(self, metrics="mean", **kwargs):
|
||||
"""
|
||||
PINA Implementation of a Lightning Callback for enriching the progress
|
||||
bar.
|
||||
@@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar):
|
||||
if pbar_metrics:
|
||||
pbar_metrics = {
|
||||
key: pbar_metrics[key] for key in self._sorted_metrics
|
||||
}
|
||||
}
|
||||
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
|
||||
if duplicates:
|
||||
rank_zero_warn(
|
||||
@@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar):
|
||||
)
|
||||
|
||||
return {**standard_metrics, **pbar_metrics}
|
||||
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
"""
|
||||
Check that the metrics defined in the initialization are available,
|
||||
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
|
||||
"""
|
||||
# Check if all keys in sort_keys are present in the dictionary
|
||||
for key in self._sorted_metrics:
|
||||
if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
|
||||
if (
|
||||
key not in trainer.solver.problem.conditions.keys()
|
||||
and key != "mean"
|
||||
):
|
||||
raise KeyError(f"Key '{key}' is not present in the dictionary")
|
||||
# add the loss pedix
|
||||
self._sorted_metrics = [
|
||||
metric + '_loss' for metric in self._sorted_metrics]
|
||||
return super().on_fit_start(trainer, pl_module)
|
||||
metric + "_loss" for metric in self._sorted_metrics
|
||||
]
|
||||
return super().on_fit_start(trainer, pl_module)
|
||||
|
||||
@@ -13,7 +13,7 @@ __all__ = [
|
||||
"FourierFeatureEmbedding",
|
||||
"AVNOBlock",
|
||||
"LowRankBlock",
|
||||
"RBFBlock"
|
||||
"RBFBlock",
|
||||
]
|
||||
|
||||
from .convolution_2d import ContinuousConvBlock
|
||||
|
||||
@@ -6,65 +6,74 @@ from itertools import combinations_with_replacement
|
||||
import torch
|
||||
from ...utils import check_consistency
|
||||
|
||||
|
||||
def linear(r):
|
||||
'''
|
||||
"""
|
||||
Linear radial basis function.
|
||||
'''
|
||||
"""
|
||||
return -r
|
||||
|
||||
|
||||
def thin_plate_spline(r, eps=1e-7):
|
||||
'''
|
||||
"""
|
||||
Thin plate spline radial basis function.
|
||||
'''
|
||||
"""
|
||||
r = torch.clamp(r, min=eps)
|
||||
return r**2 * torch.log(r)
|
||||
|
||||
|
||||
def cubic(r):
|
||||
'''
|
||||
"""
|
||||
Cubic radial basis function.
|
||||
'''
|
||||
"""
|
||||
return r**3
|
||||
|
||||
|
||||
def quintic(r):
|
||||
'''
|
||||
"""
|
||||
Quintic radial basis function.
|
||||
'''
|
||||
return -r**5
|
||||
"""
|
||||
return -(r**5)
|
||||
|
||||
|
||||
def multiquadric(r):
|
||||
'''
|
||||
"""
|
||||
Multiquadric radial basis function.
|
||||
'''
|
||||
"""
|
||||
return -torch.sqrt(r**2 + 1)
|
||||
|
||||
|
||||
def inverse_multiquadric(r):
|
||||
'''
|
||||
"""
|
||||
Inverse multiquadric radial basis function.
|
||||
'''
|
||||
return 1/torch.sqrt(r**2 + 1)
|
||||
"""
|
||||
return 1 / torch.sqrt(r**2 + 1)
|
||||
|
||||
|
||||
def inverse_quadratic(r):
|
||||
'''
|
||||
"""
|
||||
Inverse quadratic radial basis function.
|
||||
'''
|
||||
return 1/(r**2 + 1)
|
||||
"""
|
||||
return 1 / (r**2 + 1)
|
||||
|
||||
|
||||
def gaussian(r):
|
||||
'''
|
||||
"""
|
||||
Gaussian radial basis function.
|
||||
'''
|
||||
return torch.exp(-r**2)
|
||||
"""
|
||||
return torch.exp(-(r**2))
|
||||
|
||||
|
||||
radial_functions = {
|
||||
"linear": linear,
|
||||
"thin_plate_spline": thin_plate_spline,
|
||||
"cubic": cubic,
|
||||
"quintic": quintic,
|
||||
"multiquadric": multiquadric,
|
||||
"inverse_multiquadric": inverse_multiquadric,
|
||||
"inverse_quadratic": inverse_quadratic,
|
||||
"gaussian": gaussian
|
||||
}
|
||||
"linear": linear,
|
||||
"thin_plate_spline": thin_plate_spline,
|
||||
"cubic": cubic,
|
||||
"quintic": quintic,
|
||||
"multiquadric": multiquadric,
|
||||
"inverse_multiquadric": inverse_multiquadric,
|
||||
"inverse_quadratic": inverse_quadratic,
|
||||
"gaussian": gaussian,
|
||||
}
|
||||
|
||||
scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}
|
||||
|
||||
@@ -73,8 +82,8 @@ min_degree_funcs = {
|
||||
"linear": 0,
|
||||
"thin_plate_spline": 1,
|
||||
"cubic": 1,
|
||||
"quintic": 2
|
||||
}
|
||||
"quintic": 2,
|
||||
}
|
||||
|
||||
|
||||
class RBFBlock(torch.nn.Module):
|
||||
@@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module):
|
||||
it is inspired from the implementation in `torchrbf.
|
||||
<https://github.com/ArmanMaesumi/torchrbf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
neighbors=None,
|
||||
@@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module):
|
||||
if value < min_degree:
|
||||
warnings.warn(
|
||||
"`degree` is too small for this kernel. Setting to "
|
||||
f"{min_degree}.", UserWarning,
|
||||
f"{min_degree}.",
|
||||
UserWarning,
|
||||
)
|
||||
self._degree = value
|
||||
|
||||
@@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module):
|
||||
)
|
||||
|
||||
if isinstance(self.smoothing, (int, float)):
|
||||
self.smoothing = torch.full((y.shape[0],), self.smoothing
|
||||
).float().to(y.device)
|
||||
self.smoothing = (
|
||||
torch.full((y.shape[0],), self.smoothing).float().to(y.device)
|
||||
)
|
||||
|
||||
def fit(self, y, d):
|
||||
"""
|
||||
@@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module):
|
||||
raise NotImplementedError("neighbors currently not supported")
|
||||
|
||||
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
|
||||
y.device)
|
||||
y.device
|
||||
)
|
||||
if powers.shape[0] > nobs:
|
||||
raise ValueError("The data is not compatible with the "
|
||||
"requested degree.")
|
||||
raise ValueError(
|
||||
"The data is not compatible with the requested degree."
|
||||
)
|
||||
|
||||
if self.neighbors is None:
|
||||
self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y,
|
||||
self.d.reshape((self.y.shape[0], -1)),
|
||||
self.smoothing, self.kernel, self.epsilon, powers)
|
||||
self._shift, self._scale, self._coeffs = RBFBlock.solve(
|
||||
self.y,
|
||||
self.d.reshape((self.y.shape[0], -1)),
|
||||
self.smoothing,
|
||||
self.kernel,
|
||||
self.epsilon,
|
||||
powers,
|
||||
)
|
||||
|
||||
self.powers = powers
|
||||
|
||||
@@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module):
|
||||
of the interpolator
|
||||
"""
|
||||
|
||||
lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel,
|
||||
epsilon, powers)
|
||||
lhs, rhs, shift, scale = RBFBlock.build(
|
||||
y, d, smoothing, kernel, epsilon, powers
|
||||
)
|
||||
try:
|
||||
coeffs = torch.linalg.solve(lhs, rhs)
|
||||
except RuntimeError as e:
|
||||
|
||||
@@ -196,9 +196,10 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
# check consistency location
|
||||
locations_to_sample = [
|
||||
condition for condition in self.conditions
|
||||
if hasattr(self.conditions[condition], 'location')
|
||||
]
|
||||
condition
|
||||
for condition in self.conditions
|
||||
if hasattr(self.conditions[condition], "location")
|
||||
]
|
||||
if locations == "all":
|
||||
# only locations that can be sampled
|
||||
locations = locations_to_sample
|
||||
|
||||
Reference in New Issue
Block a user