vectorize Cox - de Boor recursion
Co-authored-by: Filippo Olivo <folivo@filippoolivo.com> Co-authored-by: ajacoby9 <a99jacoby@gmail.com>
This commit is contained in:
@@ -9,7 +9,9 @@ class Spline(torch.nn.Module):
|
|||||||
Spline model class.
|
Spline model class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
def __init__(
|
||||||
|
self, order=4, knots=None, control_points=None, grid_extension=True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialization of the :class:`Spline` class.
|
Initialization of the :class:`Spline` class.
|
||||||
|
|
||||||
@@ -20,7 +22,7 @@ class Spline(torch.nn.Module):
|
|||||||
``None``.
|
``None``.
|
||||||
:raises ValueError: If the order is negative.
|
:raises ValueError: If the order is negative.
|
||||||
:raises ValueError: If both knots and control points are ``None``.
|
:raises ValueError: If both knots and control points are ``None``.
|
||||||
:raises ValueError: If the knot tensor is not one-dimensional.
|
:raises ValueError: If the knot tensor is not one or two dimensional.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -33,6 +35,10 @@ class Spline(torch.nn.Module):
|
|||||||
|
|
||||||
self.order = order
|
self.order = order
|
||||||
self.k = order - 1
|
self.k = order - 1
|
||||||
|
self.grid_extension = grid_extension
|
||||||
|
|
||||||
|
# Cache for performance optimization
|
||||||
|
self._boundary_interval_idx = None
|
||||||
|
|
||||||
if knots is not None and control_points is not None:
|
if knots is not None and control_points is not None:
|
||||||
self.knots = knots
|
self.knots = knots
|
||||||
@@ -65,45 +71,154 @@ class Spline(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Knots and control points cannot be both None.")
|
raise ValueError("Knots and control points cannot be both None.")
|
||||||
|
|
||||||
if self.knots.ndim != 1:
|
if self.knots.ndim > 2:
|
||||||
raise ValueError("Knot vector must be one-dimensional.")
|
raise ValueError("Knot vector must be one or two-dimensional.")
|
||||||
|
|
||||||
def basis(self, x, k, i, t):
|
# Precompute boundary interval index for performance
|
||||||
|
self._compute_boundary_interval()
|
||||||
|
|
||||||
|
def _compute_boundary_interval(self):
|
||||||
"""
|
"""
|
||||||
Recursive method to compute the basis functions of the spline.
|
Precompute the rightmost non-degenerate interval index for performance.
|
||||||
|
This avoids the search loop in the basis function on every call.
|
||||||
|
"""
|
||||||
|
# Handle multi-dimensional knots
|
||||||
|
if self.knots.ndim > 1:
|
||||||
|
# For multi-dimensional knots, we'll handle boundary detection in
|
||||||
|
# the basis function
|
||||||
|
self._boundary_interval_idx = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# For 1D knots, find the rightmost non-degenerate interval
|
||||||
|
for i in range(len(self.knots) - 2, -1, -1):
|
||||||
|
if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found
|
||||||
|
self._boundary_interval_idx = i
|
||||||
|
return
|
||||||
|
|
||||||
|
self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0
|
||||||
|
|
||||||
|
def basis(self, x, k, knots):
|
||||||
|
"""
|
||||||
|
Compute the basis functions for the spline using an iterative approach.
|
||||||
|
This is a vectorized implementation based on the Cox-de Boor recursion.
|
||||||
|
|
||||||
:param torch.Tensor x: The points to be evaluated.
|
:param torch.Tensor x: The points to be evaluated.
|
||||||
:param int k: The spline degree.
|
:param int k: The spline degree.
|
||||||
:param int i: The index of the interval.
|
:param torch.Tensor knots: The tensor of knots.
|
||||||
:param torch.Tensor t: The tensor of knots.
|
|
||||||
:return: The basis functions evaluated at x
|
:return: The basis functions evaluated at x
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if k == 0:
|
if x.ndim == 1:
|
||||||
a = torch.where(
|
x = x.unsqueeze(1) # (batch_size, 1)
|
||||||
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
|
if x.ndim == 2:
|
||||||
|
x = x.unsqueeze(2) # (batch_size, in_dim, 1)
|
||||||
|
|
||||||
|
if knots.ndim == 1:
|
||||||
|
knots = knots.unsqueeze(0) # (1, n_knots)
|
||||||
|
if knots.ndim == 2:
|
||||||
|
knots = knots.unsqueeze(0) # (1, in_dim, n_knots)
|
||||||
|
|
||||||
|
# Base case: k=0
|
||||||
|
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
|
||||||
|
basis = basis.to(x.dtype)
|
||||||
|
|
||||||
|
if self._boundary_interval_idx is not None:
|
||||||
|
i = self._boundary_interval_idx
|
||||||
|
tolerance = 1e-10
|
||||||
|
x_squeezed = x.squeeze(-1)
|
||||||
|
knot_left = knots[..., i]
|
||||||
|
knot_right = knots[..., i + 1]
|
||||||
|
|
||||||
|
at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance
|
||||||
|
in_rightmost_interval = (
|
||||||
|
x_squeezed >= knot_left
|
||||||
|
) & at_right_boundary
|
||||||
|
|
||||||
|
if torch.any(in_rightmost_interval):
|
||||||
|
# For points at the boundary, ensure they're included in the
|
||||||
|
# rightmost interval
|
||||||
|
basis[..., i] = torch.logical_or(
|
||||||
|
basis[..., i].bool(), in_rightmost_interval
|
||||||
|
).to(basis.dtype)
|
||||||
|
|
||||||
|
# Iterative step (Cox-de Boor recursion)
|
||||||
|
for i in range(1, k + 1):
|
||||||
|
# First term of the recursion
|
||||||
|
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
|
||||||
|
denom1 = torch.where(
|
||||||
|
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1
|
||||||
)
|
)
|
||||||
if i == len(t) - self.order - 1:
|
numer1 = x - knots[..., : -(i + 1)]
|
||||||
a = torch.where(x == t[-1], 1.0, a)
|
term1 = (numer1 / denom1) * basis[..., :-1]
|
||||||
a.requires_grad_(True)
|
|
||||||
return a
|
|
||||||
|
|
||||||
if t[i + k] == t[i]:
|
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
|
||||||
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
|
denom2 = torch.where(
|
||||||
else:
|
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2
|
||||||
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
|
|
||||||
|
|
||||||
if t[i + k + 1] == t[i + 1]:
|
|
||||||
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
|
|
||||||
else:
|
|
||||||
c2 = (
|
|
||||||
(t[i + k + 1] - x)
|
|
||||||
/ (t[i + k + 1] - t[i + 1])
|
|
||||||
* self.basis(x, k - 1, i + 1, t)
|
|
||||||
)
|
)
|
||||||
|
numer2 = knots[..., i + 1 :] - x
|
||||||
|
term2 = (numer2 / denom2) * basis[..., 1:]
|
||||||
|
|
||||||
return c1 + c2
|
basis = term1 + term2
|
||||||
|
|
||||||
|
return basis
|
||||||
|
|
||||||
|
def compute_control_points(self, x_eval, y_eval):
|
||||||
|
"""
|
||||||
|
Compute control points from given evaluations using least squares.
|
||||||
|
This method fits the control points to match the target y_eval values.
|
||||||
|
"""
|
||||||
|
# (batch, in_dim)
|
||||||
|
A = self.basis(x_eval, self.k, self.knots)
|
||||||
|
# (batch, in_dim, n_basis)
|
||||||
|
|
||||||
|
in_dim = A.shape[1]
|
||||||
|
out_dim = y_eval.shape[2]
|
||||||
|
n_basis = A.shape[2]
|
||||||
|
c = torch.zeros(in_dim, out_dim, n_basis).to(A.device)
|
||||||
|
|
||||||
|
for i in range(in_dim):
|
||||||
|
# A_i is (batch, n_basis)
|
||||||
|
# y_i is (batch, out_dim)
|
||||||
|
A_i = A[:, i, :]
|
||||||
|
y_i = y_eval[:, i, :]
|
||||||
|
c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim)
|
||||||
|
c[i, :, :] = c_i.T # (out_dim, n_basis)
|
||||||
|
|
||||||
|
self.control_points = torch.nn.Parameter(c)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass for the :class:`Spline` model.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The input tensor.
|
||||||
|
:return: The output tensor.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
t = self.knots
|
||||||
|
k = self.k
|
||||||
|
c = self.control_points
|
||||||
|
|
||||||
|
# Create the basis functions
|
||||||
|
# B will have shape (batch, in_dim, n_basis)
|
||||||
|
B = self.basis(x, k, t)
|
||||||
|
|
||||||
|
# KAN case where control points are (in_dim, out_dim, n_basis)
|
||||||
|
if c.ndim == 3:
|
||||||
|
y_ij = torch.einsum(
|
||||||
|
"bil,iol->bio", B, c
|
||||||
|
) # (batch, in_dim, out_dim)
|
||||||
|
# sum over input dimensions
|
||||||
|
y = torch.sum(y_ij, dim=1) # (batch, out_dim)
|
||||||
|
# Original test case
|
||||||
|
else:
|
||||||
|
B = B.squeeze(1) # (batch, n_basis)
|
||||||
|
if c.ndim == 1:
|
||||||
|
y = torch.einsum("bi,i->b", B, c)
|
||||||
|
else:
|
||||||
|
y = torch.einsum("bi,ij->bj", B, c)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def control_points(self):
|
def control_points(self):
|
||||||
@@ -131,9 +246,12 @@ class Spline(torch.nn.Module):
|
|||||||
dim = value.get("dim", 1)
|
dim = value.get("dim", 1)
|
||||||
value = torch.zeros(n, dim)
|
value = torch.zeros(n, dim)
|
||||||
|
|
||||||
|
if not isinstance(value, torch.nn.Parameter):
|
||||||
|
value = torch.nn.Parameter(value)
|
||||||
|
|
||||||
if not isinstance(value, torch.Tensor):
|
if not isinstance(value, torch.Tensor):
|
||||||
raise ValueError("Invalid value for control_points")
|
raise ValueError("Invalid value for control_points")
|
||||||
self._control_points = torch.nn.Parameter(value, requires_grad=True)
|
self._control_points = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def knots(self):
|
def knots(self):
|
||||||
@@ -181,19 +299,6 @@ class Spline(torch.nn.Module):
|
|||||||
|
|
||||||
self._knots = value
|
self._knots = value
|
||||||
|
|
||||||
def forward(self, x):
|
# Recompute boundary interval when knots change
|
||||||
"""
|
if hasattr(self, "_boundary_interval_idx"):
|
||||||
Forward pass for the :class:`Spline` model.
|
self._compute_boundary_interval()
|
||||||
|
|
||||||
:param torch.Tensor x: The input tensor.
|
|
||||||
:return: The output tensor.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
t = self.knots
|
|
||||||
k = self.k
|
|
||||||
c = self.control_points
|
|
||||||
|
|
||||||
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
|
|
||||||
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
|
||||||
|
|
||||||
return y
|
|
||||||
Reference in New Issue
Block a user