fix doc model part 2
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for Radial Basis Function Interpolation layer."""
|
||||
"""Module for the Radial Basis Function Interpolation layer."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
@@ -10,6 +10,10 @@ from ...utils import check_consistency
|
||||
def linear(r):
|
||||
"""
|
||||
Linear radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The linear radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return -r
|
||||
|
||||
@@ -17,6 +21,11 @@ def linear(r):
|
||||
def thin_plate_spline(r, eps=1e-7):
|
||||
"""
|
||||
Thin plate spline radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:param float eps: Small value to avoid log(0).
|
||||
:return: The thin plate spline radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
r = torch.clamp(r, min=eps)
|
||||
return r**2 * torch.log(r)
|
||||
@@ -25,6 +34,10 @@ def thin_plate_spline(r, eps=1e-7):
|
||||
def cubic(r):
|
||||
"""
|
||||
Cubic radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The cubic radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return r**3
|
||||
|
||||
@@ -32,6 +45,10 @@ def cubic(r):
|
||||
def quintic(r):
|
||||
"""
|
||||
Quintic radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The quintic radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return -(r**5)
|
||||
|
||||
@@ -39,6 +56,10 @@ def quintic(r):
|
||||
def multiquadric(r):
|
||||
"""
|
||||
Multiquadric radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The multiquadric radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return -torch.sqrt(r**2 + 1)
|
||||
|
||||
@@ -46,6 +67,10 @@ def multiquadric(r):
|
||||
def inverse_multiquadric(r):
|
||||
"""
|
||||
Inverse multiquadric radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The inverse multiquadric radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return 1 / torch.sqrt(r**2 + 1)
|
||||
|
||||
@@ -53,6 +78,10 @@ def inverse_multiquadric(r):
|
||||
def inverse_quadratic(r):
|
||||
"""
|
||||
Inverse quadratic radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The inverse quadratic radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return 1 / (r**2 + 1)
|
||||
|
||||
@@ -60,6 +89,10 @@ def inverse_quadratic(r):
|
||||
def gaussian(r):
|
||||
"""
|
||||
Gaussian radial basis function.
|
||||
|
||||
:param torch.Tensor r: Distance between points.
|
||||
:return: The gaussian radial basis function.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return torch.exp(-(r**2))
|
||||
|
||||
@@ -88,13 +121,14 @@ min_degree_funcs = {
|
||||
|
||||
class RBFBlock(torch.nn.Module):
|
||||
"""
|
||||
Radial Basis Function (RBF) interpolation layer. It need to be fitted with
|
||||
the data with the method :meth:`fit`, before it can be used to interpolate
|
||||
new points. The layer is not trainable.
|
||||
Radial Basis Function (RBF) interpolation layer.
|
||||
|
||||
The user needs to fit the model with the data, before using it to
|
||||
interpolate new points. The layer is not trainable.
|
||||
|
||||
.. note::
|
||||
It reproduces the implementation of ``scipy.interpolate.RBFBlock`` and
|
||||
it is inspired from the implementation in `torchrbf.
|
||||
It reproduces the implementation of :class:`scipy.interpolate.RBFBlock`
|
||||
and it is inspired from the implementation in `torchrbf.
|
||||
<https://github.com/ArmanMaesumi/torchrbf>`_
|
||||
"""
|
||||
|
||||
@@ -107,24 +141,25 @@ class RBFBlock(torch.nn.Module):
|
||||
degree=None,
|
||||
):
|
||||
"""
|
||||
:param int neighbors: Number of neighbors to use for the
|
||||
interpolation.
|
||||
If ``None``, use all data points.
|
||||
:param float smoothing: Smoothing parameter for the interpolation.
|
||||
if 0.0, the interpolation is exact and no smoothing is applied.
|
||||
:param str kernel: Radial basis function to use. Must be one of
|
||||
``linear``, ``thin_plate_spline``, ``cubic``, ``quintic``,
|
||||
``multiquadric``, ``inverse_multiquadric``, ``inverse_quadratic``,
|
||||
or ``gaussian``.
|
||||
:param float epsilon: Shape parameter that scaled the input to
|
||||
the RBF. This defaults to 1 for kernels in ``scale_invariant``
|
||||
dictionary, and must be specified for other kernels.
|
||||
:param int degree: Degree of the added polynomial.
|
||||
For some kernels, there exists a minimum degree of the polynomial
|
||||
such that the RBF is well-posed. Those minimum degrees are specified
|
||||
in the `min_degree_funcs` dictionary above. If `degree` is less than
|
||||
the minimum degree, a warning is raised and the degree is set to the
|
||||
minimum value.
|
||||
Initialization of the :class:`RBFBlock` class.
|
||||
|
||||
:param int neighbors: The number of neighbors used for interpolation.
|
||||
If ``None``, all data are used.
|
||||
:param float smoothing: The moothing parameter for the interpolation.
|
||||
If ``0.0``, the interpolation is exact and no smoothing is applied.
|
||||
:param str kernel: The radial basis function to use.
|
||||
The available kernels are: ``linear``, ``thin_plate_spline``,
|
||||
``cubic``, ``quintic``, ``multiquadric``, ``inverse_multiquadric``,
|
||||
``inverse_quadratic``, or ``gaussian``.
|
||||
:param float epsilon: The shape parameter that scales the input to the
|
||||
RBF. Default is ``1`` for kernels in the ``scale_invariant``
|
||||
dictionary, while it must be specified for other kernels.
|
||||
:param int degree: The degree of the polynomial. Some kernels require a
|
||||
minimum degree of the polynomial to ensure that the RBF is well
|
||||
defined. These minimum degrees are specified in the
|
||||
``min_degree_funcs`` dictionary. If ``degree`` is less than the
|
||||
minimum degree required, a warning is raised and the degree is set
|
||||
to the minimum value.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -151,27 +186,39 @@ class RBFBlock(torch.nn.Module):
|
||||
@property
|
||||
def smoothing(self):
|
||||
"""
|
||||
Smoothing parameter for the interpolation.
|
||||
The smoothing parameter for the interpolation.
|
||||
|
||||
:return: The smoothing parameter.
|
||||
:rtype: float
|
||||
"""
|
||||
return self._smoothing
|
||||
|
||||
@smoothing.setter
|
||||
def smoothing(self, value):
|
||||
"""
|
||||
Set the smoothing parameter for the interpolation.
|
||||
|
||||
:param float value: The smoothing parameter.
|
||||
"""
|
||||
self._smoothing = value
|
||||
|
||||
@property
|
||||
def kernel(self):
|
||||
"""
|
||||
Radial basis function to use.
|
||||
The Radial basis function.
|
||||
|
||||
:return: The radial basis function.
|
||||
:rtype: str
|
||||
"""
|
||||
return self._kernel
|
||||
|
||||
@kernel.setter
|
||||
def kernel(self, value):
|
||||
"""
|
||||
Set the radial basis function.
|
||||
|
||||
:param str value: The radial basis function.
|
||||
"""
|
||||
if value not in radial_functions:
|
||||
raise ValueError(f"Unknown kernel: {value}")
|
||||
self._kernel = value.lower()
|
||||
@@ -179,14 +226,22 @@ class RBFBlock(torch.nn.Module):
|
||||
@property
|
||||
def epsilon(self):
|
||||
"""
|
||||
Shape parameter that scaled the input to the RBF.
|
||||
The shape parameter that scales the input to the RBF.
|
||||
|
||||
:return: The shape parameter.
|
||||
:rtype: float
|
||||
"""
|
||||
return self._epsilon
|
||||
|
||||
@epsilon.setter
|
||||
def epsilon(self, value):
|
||||
"""
|
||||
Set the shape parameter.
|
||||
|
||||
:param float value: The shape parameter.
|
||||
:raises ValueError: If the kernel requires an epsilon and it is not
|
||||
specified.
|
||||
"""
|
||||
if value is None:
|
||||
if self.kernel in scale_invariant:
|
||||
value = 1.0
|
||||
@@ -199,14 +254,23 @@ class RBFBlock(torch.nn.Module):
|
||||
@property
|
||||
def degree(self):
|
||||
"""
|
||||
Degree of the added polynomial.
|
||||
The degree of the polynomial.
|
||||
|
||||
:return: The degree of the polynomial.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._degree
|
||||
|
||||
@degree.setter
|
||||
def degree(self, value):
|
||||
"""
|
||||
Set the degree of the polynomial.
|
||||
|
||||
:param int value: The degree of the polynomial.
|
||||
:raises UserWarning: If the degree is less than the minimum required
|
||||
for the kernel.
|
||||
:raises ValueError: If the degree is less than -1.
|
||||
"""
|
||||
min_degree = min_degree_funcs.get(self.kernel, -1)
|
||||
if value is None:
|
||||
value = max(min_degree, 0)
|
||||
@@ -223,6 +287,13 @@ class RBFBlock(torch.nn.Module):
|
||||
self._degree = value
|
||||
|
||||
def _check_data(self, y, d):
|
||||
"""
|
||||
Check the data consistency.
|
||||
|
||||
:param torch.Tensor y: The tensor of data points.
|
||||
:param torch.Tensor d: The tensor of data values.
|
||||
:raises ValueError: If the data is not consistent.
|
||||
"""
|
||||
if y.ndim != 2:
|
||||
raise ValueError("y must be a 2-dimensional tensor.")
|
||||
|
||||
@@ -241,8 +312,11 @@ class RBFBlock(torch.nn.Module):
|
||||
"""
|
||||
Fit the RBF interpolator to the data.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
:param torch.Tensor y: The tensor of data points.
|
||||
:param torch.Tensor d: The tensor of data values.
|
||||
:raises NotImplementedError: If the neighbors are not ``None``.
|
||||
:raises ValueError: If the data is not compatible with the requested
|
||||
degree.
|
||||
"""
|
||||
self._check_data(y, d)
|
||||
|
||||
@@ -252,7 +326,7 @@ class RBFBlock(torch.nn.Module):
|
||||
if self.neighbors is None:
|
||||
nobs = self.y.shape[0]
|
||||
else:
|
||||
raise NotImplementedError("neighbors currently not supported")
|
||||
raise NotImplementedError("Neighbors currently not supported")
|
||||
|
||||
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
|
||||
y.device
|
||||
@@ -276,12 +350,14 @@ class RBFBlock(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Returns the interpolated data at the given points `x`.
|
||||
Forward pass.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points at which
|
||||
to query the interpolator
|
||||
|
||||
:rtype: `(n, m)` torch.Tensor of interpolated data.
|
||||
:param torch.Tensor x: The tensor of points to interpolate.
|
||||
:raises ValueError: If the input is not a 2-dimensional tensor.
|
||||
:raises ValueError: If the second dimension of the input is not the same
|
||||
as the second dimension of the data.
|
||||
:return: The interpolated data.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if x.ndim != 2:
|
||||
raise ValueError("`x` must be a 2-dimensional tensor.")
|
||||
@@ -309,25 +385,25 @@ class RBFBlock(torch.nn.Module):
|
||||
@staticmethod
|
||||
def kernel_vector(x, y, kernel_func):
|
||||
"""
|
||||
Evaluate radial functions with centers `y` for all points in `x`.
|
||||
Evaluate for all points ``x`` the radial functions with center ``y``.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points.
|
||||
:param torch.Tensor y: `(m, d)` tensor of centers.
|
||||
:param torch.Tensor x: The tensor of points.
|
||||
:param torch.Tensor y: The tensor of centers.
|
||||
:param str kernel_func: Radial basis function to use.
|
||||
|
||||
:rtype: `(n, m)` torch.Tensor of radial function values.
|
||||
:return: The radial function values.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return kernel_func(torch.cdist(x, y))
|
||||
|
||||
@staticmethod
|
||||
def polynomial_matrix(x, powers):
|
||||
"""
|
||||
Evaluate monomials at `x` with given `powers`.
|
||||
Evaluate monomials of power ``powers`` at points ``x``.
|
||||
|
||||
:param torch.Tensor x: `(n, d)` tensor of points.
|
||||
:param torch.Tensor powers: `(r, d)` tensor of powers for each monomial.
|
||||
|
||||
:rtype: `(n, r)` torch.Tensor of monomial values.
|
||||
:param torch.Tensor x: The tensor of points.
|
||||
:param torch.Tensor powers: The tensor of powers for each monomial.
|
||||
:return: The monomial values.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
x_ = torch.repeat_interleave(x, repeats=powers.shape[0], dim=0)
|
||||
powers_ = powers.repeat(x.shape[0], 1)
|
||||
@@ -336,12 +412,12 @@ class RBFBlock(torch.nn.Module):
|
||||
@staticmethod
|
||||
def kernel_matrix(x, kernel_func):
|
||||
"""
|
||||
Returns radial function values for all pairs of points in `x`.
|
||||
Return the radial function values for all pairs of points in ``x``.
|
||||
|
||||
:param torch.Tensor x: `(n, d`) tensor of points.
|
||||
:param str kernel_func: Radial basis function to use.
|
||||
|
||||
:rtype: `(n, n`) torch.Tensor of radial function values.
|
||||
:param torch.Tensor x: The tensor of points.
|
||||
:param str kernel_func: The radial basis function to use.
|
||||
:return: The radial function values.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return kernel_func(torch.cdist(x, x))
|
||||
|
||||
@@ -350,12 +426,10 @@ class RBFBlock(torch.nn.Module):
|
||||
"""
|
||||
Return the powers for each monomial in a polynomial.
|
||||
|
||||
:param int ndim: Number of variables in the polynomial.
|
||||
:param int degree: Degree of the polynomial.
|
||||
|
||||
:rtype: `(nmonos, ndim)` torch.Tensor where each row contains the powers
|
||||
for each variable in a monomial.
|
||||
|
||||
:param int ndim: The number of variables in the polynomial.
|
||||
:param int degree: The degree of the polynomial.
|
||||
:return: The powers for each monomial.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
nmonos = math.comb(degree + ndim, ndim)
|
||||
out = torch.zeros((nmonos, ndim), dtype=torch.int32)
|
||||
@@ -372,16 +446,16 @@ class RBFBlock(torch.nn.Module):
|
||||
"""
|
||||
Build the RBF linear system.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
|
||||
:param str kernel: Radial basis function to use.
|
||||
:param float epsilon: Shape parameter that scaled the input to the RBF.
|
||||
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
|
||||
|
||||
:rtype: (lhs, rhs, shift, scale) where `lhs` and `rhs` are the
|
||||
left-hand side and right-hand side of the linear system, and
|
||||
`shift` and `scale` are the shift and scale parameters.
|
||||
:param torch.Tensor y: The tensor of data points.
|
||||
:param torch.Tensor d: The tensor of data values.
|
||||
:param torch.Tensor smoothing: The tensor of smoothing parameters.
|
||||
:param str kernel: The radial basis function to use.
|
||||
:param float epsilon: The shape parameter that scales the input to the
|
||||
RBF.
|
||||
:param torch.Tensor powers: The tensor of powers for each monomial.
|
||||
:return: The left-hand side and right-hand side of the linear system,
|
||||
and the shift and scale parameters.
|
||||
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
p = d.shape[0]
|
||||
s = d.shape[1]
|
||||
@@ -413,21 +487,20 @@ class RBFBlock(torch.nn.Module):
|
||||
@staticmethod
|
||||
def solve(y, d, smoothing, kernel, epsilon, powers):
|
||||
"""
|
||||
Build then solve the RBF linear system.
|
||||
Build and solve the RBF linear system.
|
||||
|
||||
:param torch.Tensor y: (n, d) tensor of data points.
|
||||
:param torch.Tensor d: (n, m) tensor of data values.
|
||||
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
|
||||
|
||||
:param str kernel: Radial basis function to use.
|
||||
:param float epsilon: Shape parameter that scaled the input to the RBF.
|
||||
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
|
||||
:param torch.Tensor y: The tensor of data points.
|
||||
:param torch.Tensor d: The tensor of data values.
|
||||
:param torch.Tensor smoothing: The tensor of smoothing parameters.
|
||||
|
||||
:param str kernel: The radial basis function to use.
|
||||
:param float epsilon: The shape parameter that scaled the input to the
|
||||
RBF.
|
||||
:param torch.Tensor powers: The tensor of powers for each monomial.
|
||||
:raises ValueError: If the linear system is singular.
|
||||
|
||||
:rtype: (shift, scale, coeffs) where `shift` and `scale` are the
|
||||
shift and scale parameters, and `coeffs` are the coefficients
|
||||
of the interpolator
|
||||
:return: The shift and scale parameters, and the coefficients of the
|
||||
interpolator.
|
||||
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
|
||||
lhs, rhs, shift, scale = RBFBlock.build(
|
||||
|
||||
Reference in New Issue
Block a user