Rename files
This commit is contained in:
committed by
Nicola Demo
parent
886bd23fdb
commit
ff43a7492b
453
pina/model/block/rbf_block.py
Normal file
453
pina/model/block/rbf_block.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""Module for Radial Basis Function Interpolation layer."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from itertools import combinations_with_replacement
|
||||
import torch
|
||||
from ...utils import check_consistency
|
||||
|
||||
|
||||
def linear(r):
|
||||
"""
|
||||
Linear radial basis function.
|
||||
"""
|
||||
return -r
|
||||
|
||||
|
||||
def thin_plate_spline(r, eps=1e-7):
|
||||
"""
|
||||
Thin plate spline radial basis function.
|
||||
"""
|
||||
r = torch.clamp(r, min=eps)
|
||||
return r**2 * torch.log(r)
|
||||
|
||||
|
||||
def cubic(r):
|
||||
"""
|
||||
Cubic radial basis function.
|
||||
"""
|
||||
return r**3
|
||||
|
||||
|
||||
def quintic(r):
|
||||
"""
|
||||
Quintic radial basis function.
|
||||
"""
|
||||
return -(r**5)
|
||||
|
||||
|
||||
def multiquadric(r):
|
||||
"""
|
||||
Multiquadric radial basis function.
|
||||
"""
|
||||
return -torch.sqrt(r**2 + 1)
|
||||
|
||||
|
||||
def inverse_multiquadric(r):
|
||||
"""
|
||||
Inverse multiquadric radial basis function.
|
||||
"""
|
||||
return 1 / torch.sqrt(r**2 + 1)
|
||||
|
||||
|
||||
def inverse_quadratic(r):
|
||||
"""
|
||||
Inverse quadratic radial basis function.
|
||||
"""
|
||||
return 1 / (r**2 + 1)
|
||||
|
||||
|
||||
def gaussian(r):
|
||||
"""
|
||||
Gaussian radial basis function.
|
||||
"""
|
||||
return torch.exp(-(r**2))
|
||||
|
||||
|
||||
radial_functions = {
|
||||
"linear": linear,
|
||||
"thin_plate_spline": thin_plate_spline,
|
||||
"cubic": cubic,
|
||||
"quintic": quintic,
|
||||
"multiquadric": multiquadric,
|
||||
"inverse_multiquadric": inverse_multiquadric,
|
||||
"inverse_quadratic": inverse_quadratic,
|
||||
"gaussian": gaussian,
|
||||
}
|
||||
|
||||
scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}
|
||||
|
||||
min_degree_funcs = {
|
||||
"multiquadric": 0,
|
||||
"linear": 0,
|
||||
"thin_plate_spline": 1,
|
||||
"cubic": 1,
|
||||
"quintic": 2,
|
||||
}
|
||||
|
||||
|
||||
class RBFBlock(torch.nn.Module):
|
||||
"""
|
||||
Radial Basis Function (RBF) interpolation layer. It need to be fitted with
|
||||
the data with the method :meth:`fit`, before it can be used to interpolate
|
||||
new points. The layer is not trainable.
|
||||
|
||||
.. note::
|
||||
It reproduces the implementation of ``scipy.interpolate.RBFBlock`` and
|
||||
it is inspired from the implementation in `torchrbf.
|
||||
<https://github.com/ArmanMaesumi/torchrbf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
neighbors=None,
|
||||
smoothing=0.0,
|
||||
kernel="thin_plate_spline",
|
||||
epsilon=None,
|
||||
degree=None,
|
||||
):
|
||||
"""
|
||||
:param int neighbors: Number of neighbors to use for the
|
||||
interpolation.
|
||||
If ``None``, use all data points.
|
||||
:param float smoothing: Smoothing parameter for the interpolation.
|
||||
if 0.0, the interpolation is exact and no smoothing is applied.
|
||||
:param str kernel: Radial basis function to use. Must be one of
|
||||
``linear``, ``thin_plate_spline``, ``cubic``, ``quintic``,
|
||||
``multiquadric``, ``inverse_multiquadric``, ``inverse_quadratic``,
|
||||
or ``gaussian``.
|
||||
:param float epsilon: Shape parameter that scaled the input to
|
||||
the RBF. This defaults to 1 for kernels in ``scale_invariant``
|
||||
dictionary, and must be specified for other kernels.
|
||||
:param int degree: Degree of the added polynomial.
|
||||
For some kernels, there exists a minimum degree of the polynomial
|
||||
such that the RBF is well-posed. Those minimum degrees are specified
|
||||
in the `min_degree_funcs` dictionary above. If `degree` is less than
|
||||
the minimum degree, a warning is raised and the degree is set to the
|
||||
minimum value.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
check_consistency(neighbors, (int, type(None)))
|
||||
check_consistency(smoothing, (int, float, torch.Tensor))
|
||||
check_consistency(kernel, str)
|
||||
check_consistency(epsilon, (float, type(None)))
|
||||
check_consistency(degree, (int, type(None)))
|
||||
|
||||
self.neighbors = neighbors
|
||||
self.smoothing = smoothing
|
||||
self.kernel = kernel
|
||||
self.epsilon = epsilon
|
||||
self.degree = degree
|
||||
self.powers = None
|
||||
# initialize data points and values
|
||||
self.y = None
|
||||
self.d = None
|
||||
# initialize attributes for the fitted model
|
||||
self._shift = None
|
||||
self._scale = None
|
||||
self._coeffs = None
|
||||
|
||||
@property
|
||||
def smoothing(self):
|
||||
"""
|
||||
Smoothing parameter for the interpolation.
|
||||
|
||||
:rtype: float
|
||||
"""
|
||||
return self._smoothing
|
||||
|
||||
@smoothing.setter
|
||||
def smoothing(self, value):
|
||||
self._smoothing = value
|
||||
|
||||
@property
|
||||
def kernel(self):
|
||||
"""
|
||||
Radial basis function to use.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
return self._kernel
|
||||
|
||||
@kernel.setter
|
||||
def kernel(self, value):
|
||||
if value not in radial_functions:
|
||||
raise ValueError(f"Unknown kernel: {value}")
|
||||
self._kernel = value.lower()
|
||||
|
||||
@property
|
||||
def epsilon(self):
|
||||
"""
|
||||
Shape parameter that scaled the input to the RBF.
|
||||
|
||||
:rtype: float
|
||||
"""
|
||||
return self._epsilon
|
||||
|
||||
@epsilon.setter
|
||||
def epsilon(self, value):
|
||||
if value is None:
|
||||
if self.kernel in scale_invariant:
|
||||
value = 1.0
|
||||
else:
|
||||
raise ValueError("Must specify `epsilon` for this kernel.")
|
||||
else:
|
||||
value = float(value)
|
||||
self._epsilon = value
|
||||
|
||||
@property
|
||||
def degree(self):
|
||||
"""
|
||||
Degree of the added polynomial.
|
||||
|
||||
:rtype: int
|
||||
"""
|
||||
return self._degree
|
||||
|
||||
@degree.setter
|
||||
def degree(self, value):
|
||||
min_degree = min_degree_funcs.get(self.kernel, -1)
|
||||
if value is None:
|
||||
value = max(min_degree, 0)
|
||||
else:
|
||||
value = int(value)
|
||||
if value < -1:
|
||||
raise ValueError("`degree` must be at least -1.")
|
||||
if value < min_degree:
|
||||
warnings.warn(
|
||||
"`degree` is too small for this kernel. Setting to "
|
||||
f"{min_degree}.",
|
||||
UserWarning,
|
||||
)
|
||||
self._degree = value
|
||||
|
||||
def _check_data(self, y, d):
|
||||
if y.ndim != 2:
|
||||
raise ValueError("y must be a 2-dimensional tensor.")
|
||||
|
||||
if d.shape[0] != y.shape[0]:
|
||||
raise ValueError(
|
||||
"The first dim of d must have the same length as "
|
||||
"the first dim of y."
|
||||
)
|
||||
|
||||
if isinstance(self.smoothing, (int, float)):
|
||||
self.smoothing = (
|
||||
torch.full((y.shape[0],), self.smoothing).float().to(y.device)
|
||||
)
|
||||
|
||||
def fit(self, y, d):
|
||||
"""
|
||||
Fit the RBF interpolator to the data.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
"""
|
||||
self._check_data(y, d)
|
||||
|
||||
self.y = y
|
||||
self.d = d
|
||||
|
||||
if self.neighbors is None:
|
||||
nobs = self.y.shape[0]
|
||||
else:
|
||||
raise NotImplementedError("neighbors currently not supported")
|
||||
|
||||
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
|
||||
y.device
|
||||
)
|
||||
if powers.shape[0] > nobs:
|
||||
raise ValueError(
|
||||
"The data is not compatible with the requested degree."
|
||||
)
|
||||
|
||||
if self.neighbors is None:
|
||||
self._shift, self._scale, self._coeffs = RBFBlock.solve(
|
||||
self.y,
|
||||
self.d.reshape((self.y.shape[0], -1)),
|
||||
self.smoothing,
|
||||
self.kernel,
|
||||
self.epsilon,
|
||||
powers,
|
||||
)
|
||||
|
||||
self.powers = powers
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Returns the interpolated data at the given points `x`.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points at which
|
||||
to query the interpolator
|
||||
|
||||
:rtype: `(n, m)` torch.Tensor of interpolated data.
|
||||
"""
|
||||
if x.ndim != 2:
|
||||
raise ValueError("`x` must be a 2-dimensional tensor.")
|
||||
|
||||
nx, ndim = x.shape
|
||||
if ndim != self.y.shape[1]:
|
||||
raise ValueError(
|
||||
"Expected the second dim of `x` to have length "
|
||||
f"{self.y.shape[1]}."
|
||||
)
|
||||
|
||||
kernel_func = radial_functions[self.kernel]
|
||||
|
||||
yeps = self.y * self.epsilon
|
||||
xeps = x * self.epsilon
|
||||
xhat = (x - self._shift) / self._scale
|
||||
|
||||
kv = RBFBlock.kernel_vector(xeps, yeps, kernel_func)
|
||||
p = RBFBlock.polynomial_matrix(xhat, self.powers)
|
||||
vec = torch.cat([kv, p], dim=1)
|
||||
out = torch.matmul(vec, self._coeffs)
|
||||
out = out.reshape((nx,) + self.d.shape[1:])
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def kernel_vector(x, y, kernel_func):
|
||||
"""
|
||||
Evaluate radial functions with centers `y` for all points in `x`.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points.
|
||||
:param torch.Tensor y: `(m, d)` tensor of centers.
|
||||
:param str kernel_func: Radial basis function to use.
|
||||
|
||||
:rtype: `(n, m)` torch.Tensor of radial function values.
|
||||
"""
|
||||
return kernel_func(torch.cdist(x, y))
|
||||
|
||||
@staticmethod
|
||||
def polynomial_matrix(x, powers):
|
||||
"""
|
||||
Evaluate monomials at `x` with given `powers`.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points.
|
||||
:param torch.Tensor powers: `(r, d)` tensor of powers for each monomial.
|
||||
|
||||
:rtype: `(n, r)` torch.Tensor of monomial values.
|
||||
"""
|
||||
x_ = torch.repeat_interleave(x, repeats=powers.shape[0], dim=0)
|
||||
powers_ = powers.repeat(x.shape[0], 1)
|
||||
return torch.prod(x_**powers_, dim=1).view(x.shape[0], powers.shape[0])
|
||||
|
||||
@staticmethod
|
||||
def kernel_matrix(x, kernel_func):
|
||||
"""
|
||||
Returns radial function values for all pairs of points in `x`.
|
||||
|
||||
:param torch.Tensor x: `(n, d`) tensor of points.
|
||||
:param str kernel_func: Radial basis function to use.
|
||||
|
||||
:rtype: `(n, n`) torch.Tensor of radial function values.
|
||||
"""
|
||||
return kernel_func(torch.cdist(x, x))
|
||||
|
||||
@staticmethod
|
||||
def monomial_powers(ndim, degree):
|
||||
"""
|
||||
Return the powers for each monomial in a polynomial.
|
||||
|
||||
:param int ndim: Number of variables in the polynomial.
|
||||
:param int degree: Degree of the polynomial.
|
||||
|
||||
:rtype: `(nmonos, ndim)` torch.Tensor where each row contains the powers
|
||||
for each variable in a monomial.
|
||||
|
||||
"""
|
||||
nmonos = math.comb(degree + ndim, ndim)
|
||||
out = torch.zeros((nmonos, ndim), dtype=torch.int32)
|
||||
count = 0
|
||||
for deg in range(degree + 1):
|
||||
for mono in combinations_with_replacement(range(ndim), deg):
|
||||
for var in mono:
|
||||
out[count, var] += 1
|
||||
count += 1
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def build(y, d, smoothing, kernel, epsilon, powers):
|
||||
"""
|
||||
Build the RBF linear system.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
|
||||
:param str kernel: Radial basis function to use.
|
||||
:param float epsilon: Shape parameter that scaled the input to the RBF.
|
||||
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
|
||||
|
||||
:rtype: (lhs, rhs, shift, scale) where `lhs` and `rhs` are the
|
||||
left-hand side and right-hand side of the linear system, and
|
||||
`shift` and `scale` are the shift and scale parameters.
|
||||
"""
|
||||
p = d.shape[0]
|
||||
s = d.shape[1]
|
||||
r = powers.shape[0]
|
||||
kernel_func = radial_functions[kernel]
|
||||
|
||||
mins = torch.min(y, dim=0).values
|
||||
maxs = torch.max(y, dim=0).values
|
||||
shift = (maxs + mins) / 2
|
||||
scale = (maxs - mins) / 2
|
||||
|
||||
scale[scale == 0.0] = 1.0
|
||||
|
||||
yeps = y * epsilon
|
||||
yhat = (y - shift) / scale
|
||||
|
||||
lhs = torch.empty((p + r, p + r), device=d.device).float()
|
||||
lhs[:p, :p] = RBFBlock.kernel_matrix(yeps, kernel_func)
|
||||
lhs[:p, p:] = RBFBlock.polynomial_matrix(yhat, powers)
|
||||
lhs[p:, :p] = lhs[:p, p:].T
|
||||
lhs[p:, p:] = 0.0
|
||||
lhs[:p, :p] += torch.diag(smoothing)
|
||||
|
||||
rhs = torch.empty((r + p, s), device=d.device).float()
|
||||
rhs[:p] = d
|
||||
rhs[p:] = 0.0
|
||||
return lhs, rhs, shift, scale
|
||||
|
||||
@staticmethod
|
||||
def solve(y, d, smoothing, kernel, epsilon, powers):
|
||||
"""
|
||||
Build then solve the RBF linear system.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
|
||||
|
||||
:param str kernel: Radial basis function to use.
|
||||
:param float epsilon: Shape parameter that scaled the input to the RBF.
|
||||
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
|
||||
|
||||
:raises ValueError: If the linear system is singular.
|
||||
|
||||
:rtype: (shift, scale, coeffs) where `shift` and `scale` are the
|
||||
shift and scale parameters, and `coeffs` are the coefficients
|
||||
of the interpolator
|
||||
"""
|
||||
|
||||
lhs, rhs, shift, scale = RBFBlock.build(
|
||||
y, d, smoothing, kernel, epsilon, powers
|
||||
)
|
||||
try:
|
||||
coeffs = torch.linalg.solve(lhs, rhs)
|
||||
except RuntimeError as e:
|
||||
msg = "Singular matrix."
|
||||
nmonos = powers.shape[0]
|
||||
if nmonos > 0:
|
||||
pmat = RBFBlock.polynomial_matrix((y - shift) / scale, powers)
|
||||
rank = torch.linalg.matrix_rank(pmat)
|
||||
if rank < nmonos:
|
||||
msg = (
|
||||
"Singular matrix. The matrix of monomials evaluated at "
|
||||
"the data point coordinates does not have full column "
|
||||
f"rank ({rank}/{nmonos})."
|
||||
)
|
||||
|
||||
raise ValueError(msg) from e
|
||||
|
||||
return shift, scale, coeffs
|
||||
Reference in New Issue
Block a user