fix logic and extend tests

This commit is contained in:
GiovanniCanali
2025-10-03 12:31:33 +02:00
parent ad41ba05b2
commit 71ce8c55f6
2 changed files with 408 additions and 246 deletions

View File

@@ -1,224 +1,238 @@
"""Module for the Spline model class.""" """Module for the B-Spline model class."""
import torch import torch
from ..utils import check_consistency import warnings
from ..utils import check_positive_integer
class Spline(torch.nn.Module): class Spline(torch.nn.Module):
""" r"""
Spline model class. The univariate B-Spline curve model class.
A univariate B-spline curve of order :math:`k` is a parametric curve defined
as a linear combination of B-spline basis functions and control points:
.. math::
S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m]
where:
- :math:`C_i \in \mathbb{R}` are the control points. These fixed points
influence the shape of the curve but are not generally interpolated,
except at the boundaries under certain knot multiplicities.
- :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`,
i.e., piecewise polynomials of degree :math:`k-1` with support on the
interval :math:`[x_i, x_{i+k}]`.
- :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector.
If the first and last knots are repeated :math:`k` times, then the curve
interpolates the first and last control points.
.. note::
The curve is forced to be zero outside the interval defined by the
first and last knots.
:Example:
>>> from pina.model import Spline
>>> import torch
>>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
>>> spline1 = Spline(order=3, knots=knots1, control_points=None)
>>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"}
>>> spline2 = Spline(order=3, knots=knots2, control_points=None)
>>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
>>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0])
>>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3)
""" """
def __init__( def __init__(self, order=4, knots=None, control_points=None):
self, order=4, knots=None, control_points=None, grid_extension=True
):
""" """
Initialization of the :class:`Spline` class. Initialization of the :class:`Spline` class.
:param int order: The order of the spline. Default is ``4``. :param int order: The order of the spline. The corresponding basis
:param torch.Tensor knots: The tensor representing knots. If ``None``, functions are polynomials of degree ``order - 1``. Default is 4.
the knots will be initialized automatically. Default is ``None``. :param knots: The knots of the spline. If a tensor is provided, knots
:param torch.Tensor control_points: The control points. Default is are set directly from the tensor. If a dictionary is provided, it
``None``. must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``.
:raises ValueError: If the order is negative. Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"``
:raises ValueError: If both knots and control points are ``None``. define the interval, and ``"mode"`` selects the sampling strategy.
:raises ValueError: If the knot tensor is not one or two dimensional. The supported modes are ``"uniform"``, where the knots are evenly
spaced over :math:`[min, max]`, and ``"auto"``, where knots are
constructed to ensure that the spline interpolates the first and
last control points. In this case, the number of knots is adjusted
if :math:`n < 2 * order`. If None is given, knots are initialized
automatically over :math:`[0, 1]` ensuring interpolation of the
first and last control points. Default is None.
:type knots: torch.Tensor | dict
:param torch.Tensor control_points: The control points of the spline.
If None, they are initialized as learnable parameters with an
initial value of zero. Default is None.
:raises AssertionError: If ``order`` is not a positive integer.
:raises ValueError: If both ``knots`` and ``control_points`` are None.
:raises ValueError: If ``knots`` is not one-dimensional.
:raises ValueError: If ``control_points`` is not one-dimensional.
:raises ValueError: If the number of ``knots`` is not equal to the sum
of ``order`` and the number of ``control_points.``
:raises UserWarning: If the number of control points is lower than the
order, resulting in a degenerate spline.
""" """
super().__init__() super().__init__()
check_consistency(order, int) # Check consistency
check_positive_integer(value=order, strict=True)
if order < 0: # Raise error if neither knots nor control points are provided
raise ValueError("Spline order cannot be negative.")
if knots is None and control_points is None: if knots is None and control_points is None:
raise ValueError("Knots and control points cannot be both None.") raise ValueError("knots and control_points cannot both be None.")
self.order = order # Initialize knots if not provided
self.k = order - 1 if knots is None and control_points is not None:
self.grid_extension = grid_extension knots = {
"n": len(control_points) + order,
# Cache for performance optimization
self._boundary_interval_idx = None
if knots is not None and control_points is not None:
self.knots = knots
self.control_points = control_points
elif knots is not None:
print("Warning: control points will be initialized automatically.")
print(" experimental feature")
self.knots = knots
n = len(knots) - order
self.control_points = torch.nn.Parameter(
torch.zeros(n), requires_grad=True
)
elif control_points is not None:
print("Warning: knots will be initialized automatically.")
print(" experimental feature")
self.control_points = control_points
n = len(self.control_points) - 1
self.knots = {
"type": "auto",
"min": 0, "min": 0,
"max": 1, "max": 1,
"n": n + 2 + self.order, "mode": "auto",
} }
else: # Initialization - knots and control points managed by their setters
raise ValueError("Knots and control points cannot be both None.") self.order = order
self.knots = knots
self.control_points = control_points
if self.knots.ndim > 2: # Check dimensionality of knots
raise ValueError("Knot vector must be one or two-dimensional.") if self.knots.ndim > 1:
raise ValueError("knots must be one-dimensional.")
# Precompute boundary interval index for performance # Check dimensionality of control points
self._compute_boundary_interval() if self.control_points.ndim > 1:
raise ValueError("control_points must be one-dimensional.")
# Raise error if #knots != order + #control_points
if len(self.knots) != self.order + len(self.control_points):
raise ValueError(
f" The number of knots must be equal to order + number of"
f" control points. Got {len(self.knots)} knots, {self.order}"
f" order and {len(self.control_points)} control points."
)
# Raise warning if spline is degenerate
if len(self.control_points) < self.order:
warnings.warn(
"The number of control points is smaller than the spline order."
" This creates a degenerate spline with limited flexibility.",
UserWarning,
)
# Precompute boundary interval index
self._boundary_interval_idx = self._compute_boundary_interval()
def _compute_boundary_interval(self): def _compute_boundary_interval(self):
""" """
Precompute the rightmost non-degenerate interval index for performance. Precompute the index of the rightmost non-degenerate interval to improve
This avoids the search loop in the basis function on every call. performance, eliminating the need to perform a search loop in the basis
function on each call.
:return: The index of the rightmost non-degenerate interval.
:rtype: int
""" """
# Handle multi-dimensional knots # Return 0 if there is a single interval
if self.knots.ndim > 1: if len(self.knots) < 2:
# For multi-dimensional knots, we'll handle boundary detection in return 0
# the basis function
self._boundary_interval_idx = None
return
# For 1D knots, find the rightmost non-degenerate interval # Find all indices where knots are strictly increasing
for i in range(len(self.knots) - 2, -1, -1): diffs = self.knots[1:] - self.knots[:-1]
if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found valid = torch.nonzero(diffs > 0, as_tuple=False)
self._boundary_interval_idx = i
return
self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0 # If all knots are equal, return 0 for degenerate spline
if valid.numel() == 0:
return 0
def basis(self, x, k, knots): # Otherwise, return the last valid index
return int(valid[-1])
def basis(self, x):
""" """
Compute the basis functions for the spline using an iterative approach. Compute the basis functions for the spline using an iterative approach.
This is a vectorized implementation based on the Cox-de Boor recursion. This is a vectorized implementation based on the Cox-de Boor recursion.
:param torch.Tensor x: The points to be evaluated. :param torch.Tensor x: The points to be evaluated.
:param int k: The spline degree. :return: The basis functions evaluated at x.
:param torch.Tensor knots: The tensor of knots.
:return: The basis functions evaluated at x
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
# Add a final dimension to x
x = x.unsqueeze(-1)
if x.ndim == 1: # Add an initial dimension to knots
x = x.unsqueeze(1) # (batch_size, 1) knots = self.knots.unsqueeze(0)
if x.ndim == 2:
x = x.unsqueeze(2) # (batch_size, in_dim, 1)
if knots.ndim == 1: # Base case of recursion: indicator functions for the intervals
knots = knots.unsqueeze(0) # (1, n_knots)
if knots.ndim == 2:
knots = knots.unsqueeze(0) # (1, in_dim, n_knots)
# Base case: k=0
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
basis = basis.to(x.dtype) basis = basis.to(x.dtype)
# One-dimensional knots case: ensure rightmost boundary inclusion
if self._boundary_interval_idx is not None: if self._boundary_interval_idx is not None:
i = self._boundary_interval_idx
tolerance = 1e-10
x_squeezed = x.squeeze(-1)
knot_left = knots[..., i]
knot_right = knots[..., i + 1]
at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance # Extract left and right knots of the rightmost interval
in_rightmost_interval = ( knot_left = knots[..., self._boundary_interval_idx]
x_squeezed >= knot_left knot_right = knots[..., self._boundary_interval_idx + 1]
) & at_right_boundary
if torch.any(in_rightmost_interval): # Identify points at the rightmost boundary
# For points at the boundary, ensure they're included in the at_rightmost_boundary = (
# rightmost interval x.squeeze(-1) >= knot_left
basis[..., i] = torch.logical_or( ) & torch.isclose(x.squeeze(-1), knot_right, rtol=1e-8, atol=1e-10)
basis[..., i].bool(), in_rightmost_interval
# Ensure the correct value is set at the rightmost boundary
if torch.any(at_rightmost_boundary):
basis[..., self._boundary_interval_idx] = torch.logical_or(
basis[..., self._boundary_interval_idx].bool(),
at_rightmost_boundary,
).to(basis.dtype) ).to(basis.dtype)
# Iterative step (Cox-de Boor recursion) # Iterative case of recursion
for i in range(1, k + 1): for i in range(1, self.order):
# First term of the recursion
# Compute the denominators for both terms
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
# Ensure no division by zero
denom1 = torch.where( denom1 = torch.where(
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1
) )
numer1 = x - knots[..., : -(i + 1)]
term1 = (numer1 / denom1) * basis[..., :-1]
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
denom2 = torch.where( denom2 = torch.where(
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2
) )
numer2 = knots[..., i + 1 :] - x
term2 = (numer2 / denom2) * basis[..., 1:]
# Compute the two terms of the recursion
term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1]
term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:]
# Combine terms to get the new basis
basis = term1 + term2 basis = term1 + term2
return basis return basis
def compute_control_points(self, x_eval, y_eval):
"""
Compute control points from given evaluations using least squares.
This method fits the control points to match the target y_eval values.
"""
# (batch, in_dim)
A = self.basis(x_eval, self.k, self.knots)
# (batch, in_dim, n_basis)
in_dim = A.shape[1]
out_dim = y_eval.shape[2]
n_basis = A.shape[2]
c = torch.zeros(in_dim, out_dim, n_basis).to(A.device)
for i in range(in_dim):
# A_i is (batch, n_basis)
# y_i is (batch, out_dim)
A_i = A[:, i, :]
y_i = y_eval[:, i, :]
c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim)
c[i, :, :] = c_i.T # (out_dim, n_basis)
self.control_points = torch.nn.Parameter(c)
def forward(self, x): def forward(self, x):
""" """
Forward pass for the :class:`Spline` model. Forward pass for the :class:`Spline` model.
:param torch.Tensor x: The input tensor. :param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output tensor. :return: The output tensor.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
t = self.knots return torch.einsum(
k = self.k "bi, i -> b",
c = self.control_points self.basis(x.as_subclass(torch.Tensor)).squeeze(1),
self.control_points,
# Create the basis functions ).reshape(-1, 1)
# B will have shape (batch, in_dim, n_basis)
B = self.basis(x, k, t)
# KAN case where control points are (in_dim, out_dim, n_basis)
if c.ndim == 3:
y_ij = torch.einsum(
"bil,iol->bio", B, c
) # (batch, in_dim, out_dim)
# sum over input dimensions
y = torch.sum(y_ij, dim=1) # (batch, out_dim)
# Original test case
else:
B = B.squeeze(1) # (batch, n_basis)
if c.ndim == 1:
y = torch.einsum("bi,i->b", B, c)
else:
y = torch.einsum("bi,ij->bj", B, c)
return y
@property @property
def control_points(self): def control_points(self):
@@ -231,27 +245,42 @@ class Spline(torch.nn.Module):
return self._control_points return self._control_points
@control_points.setter @control_points.setter
def control_points(self, value): def control_points(self, control_points):
""" """
Set the control points of the spline. Set the control points of the spline.
:param value: The control points. :param torch.Tensor control_points: The control points tensor. If None,
:type value: torch.Tensor | dict control points are initialized to learnable parameters with zero
:raises ValueError: If invalid value is passed. initial value. Default is None.
:raises ValueError: If there are not enough knots to define the control
points, due to the relation: #knots = order + #control_points.
:raises ValueError: If control_points is not a torch.Tensor.
""" """
if isinstance(value, dict): # If control points are not provided, initialize them
if "n" not in value: if control_points is None:
raise ValueError("Invalid value for control_points")
n = value["n"]
dim = value.get("dim", 1)
value = torch.zeros(n, dim)
if not isinstance(value, torch.nn.Parameter): # Check that there are enough knots to define control points
value = torch.nn.Parameter(value) if len(self.knots) < self.order + 1:
raise ValueError(
f"Not enough knots to define control points. Got "
f"{len(self.knots)} knots, but need at least "
f"{self.order + 1}."
)
if not isinstance(value, torch.Tensor): # Initialize control points to zero
raise ValueError("Invalid value for control_points") control_points = torch.zeros(len(self.knots) - self.order)
self._control_points = value
# Check validity of control points
elif not isinstance(control_points, torch.Tensor):
raise ValueError(
"control_points must be a torch.Tensor,"
f" got {type(control_points)}"
)
# Set control points
self._control_points = torch.nn.Parameter(
control_points, requires_grad=True
)
@property @property
def knots(self): def knots(self):
@@ -268,37 +297,80 @@ class Spline(torch.nn.Module):
""" """
Set the knots of the spline. Set the knots of the spline.
:param value: The knots. :param value: The knots of the spline. If a tensor is provided, knots
are set directly from the tensor. If a dictionary is provided, it
must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``.
Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"``
define the interval, and ``"mode"`` selects the sampling strategy.
The supported modes are ``"uniform"``, where the knots are evenly
spaced over :math:`[min, max]`, and ``"auto"``, where knots are
constructed to ensure that the spline interpolates the first and
last control points. In this case, the number of knots is inferred
and the ``"n"`` key is ignored.
:type value: torch.Tensor | dict :type value: torch.Tensor | dict
:raises ValueError: If invalid value is passed. :raises ValueError: If value is not a torch.Tensor or a dictionary.
:raises ValueError: If a dictionary is provided but does not contain
the required keys.
:raises ValueError: If the mode specified in the dictionary is invalid.
""" """
# Check validity of knots
if not isinstance(value, (torch.Tensor, dict)):
raise ValueError(
"Knots must be a torch.Tensor or a dictionary,"
f" got {type(value)}."
)
# If a dictionary is provided, initialize knots accordingly
if isinstance(value, dict): if isinstance(value, dict):
type_ = value.get("type", "auto") # Check that required keys are present
min_ = value.get("min", 0) required_keys = {"n", "min", "max", "mode"}
max_ = value.get("max", 1) if not required_keys.issubset(value.keys()):
n = value.get("n", 10) raise ValueError(
f"When providing knots as a dictionary, the following "
f"keys must be present: {required_keys}. Got "
f"{value.keys()}."
)
if type_ == "uniform": # Uniform sampling of knots
value = torch.linspace(min_, max_, n + self.k + 1) if value["mode"] == "uniform":
elif type_ == "auto": value = torch.linspace(value["min"], value["max"], value["n"])
initial_knots = torch.ones(self.order + 1) * min_
final_knots = torch.ones(self.order + 1) * max_
if n < self.order + 1: # Automatic sampling of interpolating knots
value = torch.concatenate((initial_knots, final_knots)) elif value["mode"] == "auto":
elif n - 2 * self.order + 1 == 1:
value = torch.Tensor([(max_ + min_) / 2]) # Repeat the first and last knots 'order' times
initial_knots = torch.ones(self.order) * value["min"]
final_knots = torch.ones(self.order) * value["max"]
# Number of internal knots
n_internal = value["n"] - 2 * self.order
# If no internal knots are needed, just concatenate boundaries
if n_internal <= 0:
value = torch.cat((initial_knots, final_knots))
# Else, sample internal knots uniformly and exclude boundaries
# Recover the correct number of internal knots when slicing by
# adding 2 to n_internal
else: else:
value = torch.linspace(min_, max_, n - 2 * self.order - 1) internal_knots = torch.linspace(
value["min"], value["max"], n_internal + 2
)[1:-1]
value = torch.cat(
(initial_knots, internal_knots, final_knots)
)
value = torch.concatenate((initial_knots, value, final_knots)) # Raise error if mode is invalid
else:
raise ValueError(
f"Invalid mode for knots initialization. Got "
f"{value['mode']}, but expected 'uniform' or 'auto'."
)
if not isinstance(value, torch.Tensor): # Set knots
raise ValueError("Invalid value for knots") self.register_buffer("_knots", value.sort(dim=0).values)
self._knots = value
# Recompute boundary interval when knots change # Recompute boundary interval when knots change
if hasattr(self, "_boundary_interval_idx"): if hasattr(self, "_boundary_interval_idx"):
self._compute_boundary_interval() self._boundary_interval_idx = self._compute_boundary_interval()

View File

@@ -1,81 +1,171 @@
import torch import torch
import pytest import pytest
import numpy as np
from scipy.interpolate import BSpline
from pina.model import Spline from pina.model import Spline
from pina import LabelTensor
data = torch.rand((20, 3))
input_vars = 3
output_vars = 4
valid_args = [
{
"knots": torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0]),
"control_points": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]),
"order": 3,
},
{
"knots": torch.tensor(
[-2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0]
),
"control_points": torch.tensor([0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]),
"order": 4,
},
# {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
# {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
]
def scipy_check(model, x, y): # Utility quantities for testing
from scipy.interpolate._bsplines import BSpline order = torch.randint(1, 8, (1,)).item()
import numpy as np n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
n_knots = order + n_ctrl_pts
spline = BSpline( # Input tensor
pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"])
# Function to compare with scipy implementation
def check_scipy_spline(model, x, output_):
# Define scipy spline
scipy_spline = BSpline(
t=model.knots.detach().numpy(), t=model.knots.detach().numpy(),
c=model.control_points.detach().numpy(), c=model.control_points.detach().numpy(),
k=model.order - 1, k=model.order - 1,
) )
y_scipy = spline(x).flatten()
y = y.detach().numpy() # Compare outputs
np.testing.assert_allclose(y, y_scipy, atol=1e-5) np.testing.assert_allclose(
output_.squeeze().detach().numpy(),
scipy_spline(x).flatten(),
atol=1e-5,
rtol=1e-5,
)
# Define all possible combinations of valid arguments for the Spline class
valid_args = [
{
"order": order,
"control_points": torch.rand(n_ctrl_pts),
"knots": torch.linspace(0, 1, n_knots),
},
{
"order": order,
"control_points": torch.rand(n_ctrl_pts),
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"},
},
{
"order": order,
"control_points": torch.rand(n_ctrl_pts),
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"},
},
{
"order": order,
"control_points": None,
"knots": torch.linspace(0, 1, n_knots),
},
{
"order": order,
"control_points": None,
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"},
},
{
"order": order,
"control_points": None,
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"},
},
{
"order": order,
"control_points": torch.rand(n_ctrl_pts),
"knots": None,
},
]
@pytest.mark.parametrize("args", valid_args) @pytest.mark.parametrize("args", valid_args)
def test_constructor(args): def test_constructor(args):
Spline(**args) Spline(**args)
# Should fail if order is not a positive integer
with pytest.raises(AssertionError):
Spline(
order=-1, control_points=args["control_points"], knots=args["knots"]
)
def test_constructor_wrong(): # Should fail if control_points is not None or a torch.Tensor
with pytest.raises(ValueError): with pytest.raises(ValueError):
Spline() Spline(
order=args["order"], control_points=[1, 2, 3], knots=args["knots"]
)
# Should fail if knots is not None, a torch.Tensor, or a dict
with pytest.raises(ValueError):
Spline(
order=args["order"], control_points=args["control_points"], knots=5
)
# Should fail if both knots and control_points are None
with pytest.raises(ValueError):
Spline(order=args["order"], control_points=None, knots=None)
# Should fail if knots is not one-dimensional
with pytest.raises(ValueError):
Spline(
order=args["order"],
control_points=args["control_points"],
knots=torch.rand(n_knots, 4),
)
# Should fail if control_points is not one-dimensional
with pytest.raises(ValueError):
Spline(
order=args["order"],
control_points=torch.rand(n_ctrl_pts, 4),
knots=args["knots"],
)
# Should fail if the number of knots != order + number of control points
# If control points are None, they are initialized to fulfill this condition
if args["control_points"] is not None:
with pytest.raises(ValueError):
Spline(
order=args["order"],
control_points=args["control_points"],
knots=torch.linspace(0, 1, n_knots + 1),
)
# Should fail if the knot dict is missing required keys
with pytest.raises(ValueError):
Spline(
order=args["order"],
control_points=args["control_points"],
knots={"n": n_knots, "min": 0, "max": 1},
)
# Should fail if the knot dict has invalid 'mode' key
with pytest.raises(ValueError):
Spline(
order=args["order"],
control_points=args["control_points"],
knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"},
)
@pytest.mark.parametrize("args", valid_args) @pytest.mark.parametrize("args", valid_args)
def test_forward(args): def test_forward(args):
min_x = args["knots"][0]
max_x = args["knots"][-1] # Define the model
xi = torch.linspace(min_x, max_x, 1000)
model = Spline(**args) model = Spline(**args)
yi = model(xi).squeeze()
scipy_check(model, xi, yi) # Evaluate the model
return output_ = model(pts)
assert output_.shape == (pts.shape[0], 1)
# Compare with scipy implementation only for interpolant knots (mode: auto)
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
check_scipy_spline(model, pts, output_)
@pytest.mark.parametrize("args", valid_args) @pytest.mark.parametrize("args", valid_args)
def test_backward(args): def test_backward(args):
min_x = args["knots"][0]
max_x = args["knots"][-1]
xi = torch.linspace(min_x, max_x, 100)
model = Spline(**args)
yi = model(xi)
fake_loss = torch.sum(yi)
assert model.control_points.grad is None
fake_loss.backward()
assert model.control_points.grad is not None
# dim_in, dim_out = 3, 2 # Define the model
# fnn = FeedForward(dim_in, dim_out) model = Spline(**args)
# data.requires_grad = True
# output_ = fnn(data) # Evaluate the model
# l=torch.mean(output_) output_ = model(pts)
# l.backward() loss = torch.mean(output_)
# assert data._grad.shape == torch.Size([20,3]) loss.backward()
assert model.control_points.grad.shape == model.control_points.shape