Format Python code with psf/black push (#325)

* 🎨 Format Python code with psf/black
This commit is contained in:
github-actions[bot]
2024-08-12 18:30:46 +02:00
committed by GitHub
parent cce9876751
commit 5445559cb2
5 changed files with 85 additions and 56 deletions

View File

@@ -6,65 +6,74 @@ from itertools import combinations_with_replacement
import torch
from ...utils import check_consistency
def linear(r):
'''
"""
Linear radial basis function.
'''
"""
return -r
def thin_plate_spline(r, eps=1e-7):
'''
"""
Thin plate spline radial basis function.
'''
"""
r = torch.clamp(r, min=eps)
return r**2 * torch.log(r)
def cubic(r):
'''
"""
Cubic radial basis function.
'''
"""
return r**3
def quintic(r):
'''
"""
Quintic radial basis function.
'''
return -r**5
"""
return -(r**5)
def multiquadric(r):
'''
"""
Multiquadric radial basis function.
'''
"""
return -torch.sqrt(r**2 + 1)
def inverse_multiquadric(r):
'''
"""
Inverse multiquadric radial basis function.
'''
return 1/torch.sqrt(r**2 + 1)
"""
return 1 / torch.sqrt(r**2 + 1)
def inverse_quadratic(r):
'''
"""
Inverse quadratic radial basis function.
'''
return 1/(r**2 + 1)
"""
return 1 / (r**2 + 1)
def gaussian(r):
'''
"""
Gaussian radial basis function.
'''
return torch.exp(-r**2)
"""
return torch.exp(-(r**2))
radial_functions = {
"linear": linear,
"thin_plate_spline": thin_plate_spline,
"cubic": cubic,
"quintic": quintic,
"multiquadric": multiquadric,
"inverse_multiquadric": inverse_multiquadric,
"inverse_quadratic": inverse_quadratic,
"gaussian": gaussian
}
"linear": linear,
"thin_plate_spline": thin_plate_spline,
"cubic": cubic,
"quintic": quintic,
"multiquadric": multiquadric,
"inverse_multiquadric": inverse_multiquadric,
"inverse_quadratic": inverse_quadratic,
"gaussian": gaussian,
}
scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}
@@ -73,8 +82,8 @@ min_degree_funcs = {
"linear": 0,
"thin_plate_spline": 1,
"cubic": 1,
"quintic": 2
}
"quintic": 2,
}
class RBFBlock(torch.nn.Module):
@@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module):
it is inspired from the implementation in `torchrbf.
<https://github.com/ArmanMaesumi/torchrbf>`_
"""
def __init__(
self,
neighbors=None,
@@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module):
if value < min_degree:
warnings.warn(
"`degree` is too small for this kernel. Setting to "
f"{min_degree}.", UserWarning,
f"{min_degree}.",
UserWarning,
)
self._degree = value
@@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module):
)
if isinstance(self.smoothing, (int, float)):
self.smoothing = torch.full((y.shape[0],), self.smoothing
).float().to(y.device)
self.smoothing = (
torch.full((y.shape[0],), self.smoothing).float().to(y.device)
)
def fit(self, y, d):
"""
@@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module):
raise NotImplementedError("neighbors currently not supported")
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
y.device)
y.device
)
if powers.shape[0] > nobs:
raise ValueError("The data is not compatible with the "
"requested degree.")
raise ValueError(
"The data is not compatible with the requested degree."
)
if self.neighbors is None:
self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y,
self.d.reshape((self.y.shape[0], -1)),
self.smoothing, self.kernel, self.epsilon, powers)
self._shift, self._scale, self._coeffs = RBFBlock.solve(
self.y,
self.d.reshape((self.y.shape[0], -1)),
self.smoothing,
self.kernel,
self.epsilon,
powers,
)
self.powers = powers
@@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module):
of the interpolator
"""
lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel,
epsilon, powers)
lhs, rhs, shift, scale = RBFBlock.build(
y, d, smoothing, kernel, epsilon, powers
)
try:
coeffs = torch.linalg.solve(lhs, rhs)
except RuntimeError as e: