fix doc model part 2

This commit is contained in:
giovanni
2025-03-14 16:07:08 +01:00
committed by Nicola Demo
parent 001d1fc9cf
commit f9881a79b5
18 changed files with 887 additions and 851 deletions

View File

@@ -1,4 +1,4 @@
"""Module for Radial Basis Function Interpolation layer."""
"""Module for the Radial Basis Function Interpolation layer."""
import math
import warnings
@@ -10,6 +10,10 @@ from ...utils import check_consistency
def linear(r):
"""
Linear radial basis function.
:param torch.Tensor r: Distance between points.
:return: The linear radial basis function.
:rtype: torch.Tensor
"""
return -r
@@ -17,6 +21,11 @@ def linear(r):
def thin_plate_spline(r, eps=1e-7):
"""
Thin plate spline radial basis function.
:param torch.Tensor r: Distance between points.
:param float eps: Small value to avoid log(0).
:return: The thin plate spline radial basis function.
:rtype: torch.Tensor
"""
r = torch.clamp(r, min=eps)
return r**2 * torch.log(r)
@@ -25,6 +34,10 @@ def thin_plate_spline(r, eps=1e-7):
def cubic(r):
"""
Cubic radial basis function.
:param torch.Tensor r: Distance between points.
:return: The cubic radial basis function.
:rtype: torch.Tensor
"""
return r**3
@@ -32,6 +45,10 @@ def cubic(r):
def quintic(r):
"""
Quintic radial basis function.
:param torch.Tensor r: Distance between points.
:return: The quintic radial basis function.
:rtype: torch.Tensor
"""
return -(r**5)
@@ -39,6 +56,10 @@ def quintic(r):
def multiquadric(r):
"""
Multiquadric radial basis function.
:param torch.Tensor r: Distance between points.
:return: The multiquadric radial basis function.
:rtype: torch.Tensor
"""
return -torch.sqrt(r**2 + 1)
@@ -46,6 +67,10 @@ def multiquadric(r):
def inverse_multiquadric(r):
"""
Inverse multiquadric radial basis function.
:param torch.Tensor r: Distance between points.
:return: The inverse multiquadric radial basis function.
:rtype: torch.Tensor
"""
return 1 / torch.sqrt(r**2 + 1)
@@ -53,6 +78,10 @@ def inverse_multiquadric(r):
def inverse_quadratic(r):
"""
Inverse quadratic radial basis function.
:param torch.Tensor r: Distance between points.
:return: The inverse quadratic radial basis function.
:rtype: torch.Tensor
"""
return 1 / (r**2 + 1)
@@ -60,6 +89,10 @@ def inverse_quadratic(r):
def gaussian(r):
"""
Gaussian radial basis function.
:param torch.Tensor r: Distance between points.
:return: The gaussian radial basis function.
:rtype: torch.Tensor
"""
return torch.exp(-(r**2))
@@ -88,13 +121,14 @@ min_degree_funcs = {
class RBFBlock(torch.nn.Module):
"""
Radial Basis Function (RBF) interpolation layer. It need to be fitted with
the data with the method :meth:`fit`, before it can be used to interpolate
new points. The layer is not trainable.
Radial Basis Function (RBF) interpolation layer.
The user needs to fit the model with the data, before using it to
interpolate new points. The layer is not trainable.
.. note::
It reproduces the implementation of ``scipy.interpolate.RBFBlock`` and
it is inspired from the implementation in `torchrbf.
It reproduces the implementation of :class:`scipy.interpolate.RBFBlock`
and it is inspired from the implementation in `torchrbf.
<https://github.com/ArmanMaesumi/torchrbf>`_
"""
@@ -107,24 +141,25 @@ class RBFBlock(torch.nn.Module):
degree=None,
):
"""
:param int neighbors: Number of neighbors to use for the
interpolation.
If ``None``, use all data points.
:param float smoothing: Smoothing parameter for the interpolation.
if 0.0, the interpolation is exact and no smoothing is applied.
:param str kernel: Radial basis function to use. Must be one of
``linear``, ``thin_plate_spline``, ``cubic``, ``quintic``,
``multiquadric``, ``inverse_multiquadric``, ``inverse_quadratic``,
or ``gaussian``.
:param float epsilon: Shape parameter that scaled the input to
the RBF. This defaults to 1 for kernels in ``scale_invariant``
dictionary, and must be specified for other kernels.
:param int degree: Degree of the added polynomial.
For some kernels, there exists a minimum degree of the polynomial
such that the RBF is well-posed. Those minimum degrees are specified
in the `min_degree_funcs` dictionary above. If `degree` is less than
the minimum degree, a warning is raised and the degree is set to the
minimum value.
Initialization of the :class:`RBFBlock` class.
:param int neighbors: The number of neighbors used for interpolation.
If ``None``, all data are used.
:param float smoothing: The moothing parameter for the interpolation.
If ``0.0``, the interpolation is exact and no smoothing is applied.
:param str kernel: The radial basis function to use.
The available kernels are: ``linear``, ``thin_plate_spline``,
``cubic``, ``quintic``, ``multiquadric``, ``inverse_multiquadric``,
``inverse_quadratic``, or ``gaussian``.
:param float epsilon: The shape parameter that scales the input to the
RBF. Default is ``1`` for kernels in the ``scale_invariant``
dictionary, while it must be specified for other kernels.
:param int degree: The degree of the polynomial. Some kernels require a
minimum degree of the polynomial to ensure that the RBF is well
defined. These minimum degrees are specified in the
``min_degree_funcs`` dictionary. If ``degree`` is less than the
minimum degree required, a warning is raised and the degree is set
to the minimum value.
"""
super().__init__()
@@ -151,27 +186,39 @@ class RBFBlock(torch.nn.Module):
@property
def smoothing(self):
"""
Smoothing parameter for the interpolation.
The smoothing parameter for the interpolation.
:return: The smoothing parameter.
:rtype: float
"""
return self._smoothing
@smoothing.setter
def smoothing(self, value):
"""
Set the smoothing parameter for the interpolation.
:param float value: The smoothing parameter.
"""
self._smoothing = value
@property
def kernel(self):
"""
Radial basis function to use.
The Radial basis function.
:return: The radial basis function.
:rtype: str
"""
return self._kernel
@kernel.setter
def kernel(self, value):
"""
Set the radial basis function.
:param str value: The radial basis function.
"""
if value not in radial_functions:
raise ValueError(f"Unknown kernel: {value}")
self._kernel = value.lower()
@@ -179,14 +226,22 @@ class RBFBlock(torch.nn.Module):
@property
def epsilon(self):
"""
Shape parameter that scaled the input to the RBF.
The shape parameter that scales the input to the RBF.
:return: The shape parameter.
:rtype: float
"""
return self._epsilon
@epsilon.setter
def epsilon(self, value):
"""
Set the shape parameter.
:param float value: The shape parameter.
:raises ValueError: If the kernel requires an epsilon and it is not
specified.
"""
if value is None:
if self.kernel in scale_invariant:
value = 1.0
@@ -199,14 +254,23 @@ class RBFBlock(torch.nn.Module):
@property
def degree(self):
"""
Degree of the added polynomial.
The degree of the polynomial.
:return: The degree of the polynomial.
:rtype: int
"""
return self._degree
@degree.setter
def degree(self, value):
"""
Set the degree of the polynomial.
:param int value: The degree of the polynomial.
:raises UserWarning: If the degree is less than the minimum required
for the kernel.
:raises ValueError: If the degree is less than -1.
"""
min_degree = min_degree_funcs.get(self.kernel, -1)
if value is None:
value = max(min_degree, 0)
@@ -223,6 +287,13 @@ class RBFBlock(torch.nn.Module):
self._degree = value
def _check_data(self, y, d):
"""
Check the data consistency.
:param torch.Tensor y: The tensor of data points.
:param torch.Tensor d: The tensor of data values.
:raises ValueError: If the data is not consistent.
"""
if y.ndim != 2:
raise ValueError("y must be a 2-dimensional tensor.")
@@ -241,8 +312,11 @@ class RBFBlock(torch.nn.Module):
"""
Fit the RBF interpolator to the data.
:param torch.Tensor y: (n, d) tensor of data points.
:param torch.Tensor d: (n, m) tensor of data values.
:param torch.Tensor y: The tensor of data points.
:param torch.Tensor d: The tensor of data values.
:raises NotImplementedError: If the neighbors are not ``None``.
:raises ValueError: If the data is not compatible with the requested
degree.
"""
self._check_data(y, d)
@@ -252,7 +326,7 @@ class RBFBlock(torch.nn.Module):
if self.neighbors is None:
nobs = self.y.shape[0]
else:
raise NotImplementedError("neighbors currently not supported")
raise NotImplementedError("Neighbors currently not supported")
powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
y.device
@@ -276,12 +350,14 @@ class RBFBlock(torch.nn.Module):
def forward(self, x):
"""
Returns the interpolated data at the given points `x`.
Forward pass.
:param torch.Tensor x: `(n, d)` tensor of points at which
to query the interpolator
:rtype: `(n, m)` torch.Tensor of interpolated data.
:param torch.Tensor x: The tensor of points to interpolate.
:raises ValueError: If the input is not a 2-dimensional tensor.
:raises ValueError: If the second dimension of the input is not the same
as the second dimension of the data.
:return: The interpolated data.
:rtype: torch.Tensor
"""
if x.ndim != 2:
raise ValueError("`x` must be a 2-dimensional tensor.")
@@ -309,25 +385,25 @@ class RBFBlock(torch.nn.Module):
@staticmethod
def kernel_vector(x, y, kernel_func):
"""
Evaluate radial functions with centers `y` for all points in `x`.
Evaluate for all points ``x`` the radial functions with center ``y``.
:param torch.Tensor x: `(n, d)` tensor of points.
:param torch.Tensor y: `(m, d)` tensor of centers.
:param torch.Tensor x: The tensor of points.
:param torch.Tensor y: The tensor of centers.
:param str kernel_func: Radial basis function to use.
:rtype: `(n, m)` torch.Tensor of radial function values.
:return: The radial function values.
:rtype: torch.Tensor
"""
return kernel_func(torch.cdist(x, y))
@staticmethod
def polynomial_matrix(x, powers):
"""
Evaluate monomials at `x` with given `powers`.
Evaluate monomials of power ``powers`` at points ``x``.
:param torch.Tensor x: `(n, d)` tensor of points.
:param torch.Tensor powers: `(r, d)` tensor of powers for each monomial.
:rtype: `(n, r)` torch.Tensor of monomial values.
:param torch.Tensor x: The tensor of points.
:param torch.Tensor powers: The tensor of powers for each monomial.
:return: The monomial values.
:rtype: torch.Tensor
"""
x_ = torch.repeat_interleave(x, repeats=powers.shape[0], dim=0)
powers_ = powers.repeat(x.shape[0], 1)
@@ -336,12 +412,12 @@ class RBFBlock(torch.nn.Module):
@staticmethod
def kernel_matrix(x, kernel_func):
"""
Returns radial function values for all pairs of points in `x`.
Return the radial function values for all pairs of points in ``x``.
:param torch.Tensor x: `(n, d`) tensor of points.
:param str kernel_func: Radial basis function to use.
:rtype: `(n, n`) torch.Tensor of radial function values.
:param torch.Tensor x: The tensor of points.
:param str kernel_func: The radial basis function to use.
:return: The radial function values.
:rtype: torch.Tensor
"""
return kernel_func(torch.cdist(x, x))
@@ -350,12 +426,10 @@ class RBFBlock(torch.nn.Module):
"""
Return the powers for each monomial in a polynomial.
:param int ndim: Number of variables in the polynomial.
:param int degree: Degree of the polynomial.
:rtype: `(nmonos, ndim)` torch.Tensor where each row contains the powers
for each variable in a monomial.
:param int ndim: The number of variables in the polynomial.
:param int degree: The degree of the polynomial.
:return: The powers for each monomial.
:rtype: torch.Tensor
"""
nmonos = math.comb(degree + ndim, ndim)
out = torch.zeros((nmonos, ndim), dtype=torch.int32)
@@ -372,16 +446,16 @@ class RBFBlock(torch.nn.Module):
"""
Build the RBF linear system.
:param torch.Tensor y: (n, d) tensor of data points.
:param torch.Tensor d: (n, m) tensor of data values.
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
:param str kernel: Radial basis function to use.
:param float epsilon: Shape parameter that scaled the input to the RBF.
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
:rtype: (lhs, rhs, shift, scale) where `lhs` and `rhs` are the
left-hand side and right-hand side of the linear system, and
`shift` and `scale` are the shift and scale parameters.
:param torch.Tensor y: The tensor of data points.
:param torch.Tensor d: The tensor of data values.
:param torch.Tensor smoothing: The tensor of smoothing parameters.
:param str kernel: The radial basis function to use.
:param float epsilon: The shape parameter that scales the input to the
RBF.
:param torch.Tensor powers: The tensor of powers for each monomial.
:return: The left-hand side and right-hand side of the linear system,
and the shift and scale parameters.
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""
p = d.shape[0]
s = d.shape[1]
@@ -413,21 +487,20 @@ class RBFBlock(torch.nn.Module):
@staticmethod
def solve(y, d, smoothing, kernel, epsilon, powers):
"""
Build then solve the RBF linear system.
Build and solve the RBF linear system.
:param torch.Tensor y: (n, d) tensor of data points.
:param torch.Tensor d: (n, m) tensor of data values.
:param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
:param str kernel: Radial basis function to use.
:param float epsilon: Shape parameter that scaled the input to the RBF.
:param torch.Tensor powers: (r, d) tensor of powers for each monomial.
:param torch.Tensor y: The tensor of data points.
:param torch.Tensor d: The tensor of data values.
:param torch.Tensor smoothing: The tensor of smoothing parameters.
:param str kernel: The radial basis function to use.
:param float epsilon: The shape parameter that scaled the input to the
RBF.
:param torch.Tensor powers: The tensor of powers for each monomial.
:raises ValueError: If the linear system is singular.
:rtype: (shift, scale, coeffs) where `shift` and `scale` are the
shift and scale parameters, and `coeffs` are the coefficients
of the interpolator
:return: The shift and scale parameters, and the coefficients of the
interpolator.
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""
lhs, rhs, shift, scale = RBFBlock.build(