diff --git a/pina/model/spline.py b/pina/model/spline.py index c6f3c55..6800384 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -1,224 +1,238 @@ -"""Module for the Spline model class.""" +"""Module for the B-Spline model class.""" import torch -from ..utils import check_consistency +import warnings +from ..utils import check_positive_integer class Spline(torch.nn.Module): - """ - Spline model class. + r""" + The univariate B-Spline curve model class. + + A univariate B-spline curve of order :math:`k` is a parametric curve defined + as a linear combination of B-spline basis functions and control points: + + .. math:: + + S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m] + + where: + + - :math:`C_i \in \mathbb{R}` are the control points. These fixed points + influence the shape of the curve but are not generally interpolated, + except at the boundaries under certain knot multiplicities. + - :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`, + i.e., piecewise polynomials of degree :math:`k-1` with support on the + interval :math:`[x_i, x_{i+k}]`. + - :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector. + + If the first and last knots are repeated :math:`k` times, then the curve + interpolates the first and last control points. + + + .. note:: + + The curve is forced to be zero outside the interval defined by the + first and last knots. + + + :Example: + + >>> from pina.model import Spline + >>> import torch + + >>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) + >>> spline1 = Spline(order=3, knots=knots1, control_points=None) + + >>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"} + >>> spline2 = Spline(order=3, knots=knots2, control_points=None) + + >>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) + >>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0]) + >>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3) """ - def __init__( - self, order=4, knots=None, control_points=None, grid_extension=True - ): + def __init__(self, order=4, knots=None, control_points=None): """ Initialization of the :class:`Spline` class. - :param int order: The order of the spline. Default is ``4``. - :param torch.Tensor knots: The tensor representing knots. If ``None``, - the knots will be initialized automatically. Default is ``None``. - :param torch.Tensor control_points: The control points. Default is - ``None``. - :raises ValueError: If the order is negative. - :raises ValueError: If both knots and control points are ``None``. - :raises ValueError: If the knot tensor is not one or two dimensional. + :param int order: The order of the spline. The corresponding basis + functions are polynomials of degree ``order - 1``. Default is 4. + :param knots: The knots of the spline. If a tensor is provided, knots + are set directly from the tensor. If a dictionary is provided, it + must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. + Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` + define the interval, and ``"mode"`` selects the sampling strategy. + The supported modes are ``"uniform"``, where the knots are evenly + spaced over :math:`[min, max]`, and ``"auto"``, where knots are + constructed to ensure that the spline interpolates the first and + last control points. In this case, the number of knots is adjusted + if :math:`n < 2 * order`. If None is given, knots are initialized + automatically over :math:`[0, 1]` ensuring interpolation of the + first and last control points. Default is None. + :type knots: torch.Tensor | dict + :param torch.Tensor control_points: The control points of the spline. + If None, they are initialized as learnable parameters with an + initial value of zero. Default is None. + :raises AssertionError: If ``order`` is not a positive integer. + :raises ValueError: If both ``knots`` and ``control_points`` are None. + :raises ValueError: If ``knots`` is not one-dimensional. + :raises ValueError: If ``control_points`` is not one-dimensional. + :raises ValueError: If the number of ``knots`` is not equal to the sum + of ``order`` and the number of ``control_points.`` + :raises UserWarning: If the number of control points is lower than the + order, resulting in a degenerate spline. """ super().__init__() - check_consistency(order, int) + # Check consistency + check_positive_integer(value=order, strict=True) - if order < 0: - raise ValueError("Spline order cannot be negative.") + # Raise error if neither knots nor control points are provided if knots is None and control_points is None: - raise ValueError("Knots and control points cannot be both None.") + raise ValueError("knots and control_points cannot both be None.") - self.order = order - self.k = order - 1 - self.grid_extension = grid_extension - - # Cache for performance optimization - self._boundary_interval_idx = None - - if knots is not None and control_points is not None: - self.knots = knots - self.control_points = control_points - - elif knots is not None: - print("Warning: control points will be initialized automatically.") - print(" experimental feature") - - self.knots = knots - n = len(knots) - order - self.control_points = torch.nn.Parameter( - torch.zeros(n), requires_grad=True - ) - - elif control_points is not None: - print("Warning: knots will be initialized automatically.") - print(" experimental feature") - - self.control_points = control_points - - n = len(self.control_points) - 1 - self.knots = { - "type": "auto", + # Initialize knots if not provided + if knots is None and control_points is not None: + knots = { + "n": len(control_points) + order, "min": 0, "max": 1, - "n": n + 2 + self.order, + "mode": "auto", } - else: - raise ValueError("Knots and control points cannot be both None.") + # Initialization - knots and control points managed by their setters + self.order = order + self.knots = knots + self.control_points = control_points - if self.knots.ndim > 2: - raise ValueError("Knot vector must be one or two-dimensional.") + # Check dimensionality of knots + if self.knots.ndim > 1: + raise ValueError("knots must be one-dimensional.") - # Precompute boundary interval index for performance - self._compute_boundary_interval() + # Check dimensionality of control points + if self.control_points.ndim > 1: + raise ValueError("control_points must be one-dimensional.") + + # Raise error if #knots != order + #control_points + if len(self.knots) != self.order + len(self.control_points): + raise ValueError( + f" The number of knots must be equal to order + number of" + f" control points. Got {len(self.knots)} knots, {self.order}" + f" order and {len(self.control_points)} control points." + ) + + # Raise warning if spline is degenerate + if len(self.control_points) < self.order: + warnings.warn( + "The number of control points is smaller than the spline order." + " This creates a degenerate spline with limited flexibility.", + UserWarning, + ) + + # Precompute boundary interval index + self._boundary_interval_idx = self._compute_boundary_interval() def _compute_boundary_interval(self): """ - Precompute the rightmost non-degenerate interval index for performance. - This avoids the search loop in the basis function on every call. + Precompute the index of the rightmost non-degenerate interval to improve + performance, eliminating the need to perform a search loop in the basis + function on each call. + + :return: The index of the rightmost non-degenerate interval. + :rtype: int """ - # Handle multi-dimensional knots - if self.knots.ndim > 1: - # For multi-dimensional knots, we'll handle boundary detection in - # the basis function - self._boundary_interval_idx = None - return + # Return 0 if there is a single interval + if len(self.knots) < 2: + return 0 - # For 1D knots, find the rightmost non-degenerate interval - for i in range(len(self.knots) - 2, -1, -1): - if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found - self._boundary_interval_idx = i - return + # Find all indices where knots are strictly increasing + diffs = self.knots[1:] - self.knots[:-1] + valid = torch.nonzero(diffs > 0, as_tuple=False) - self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0 + # If all knots are equal, return 0 for degenerate spline + if valid.numel() == 0: + return 0 - def basis(self, x, k, knots): + # Otherwise, return the last valid index + return int(valid[-1]) + + def basis(self, x): """ Compute the basis functions for the spline using an iterative approach. This is a vectorized implementation based on the Cox-de Boor recursion. :param torch.Tensor x: The points to be evaluated. - :param int k: The spline degree. - :param torch.Tensor knots: The tensor of knots. - :return: The basis functions evaluated at x + :return: The basis functions evaluated at x. :rtype: torch.Tensor """ + # Add a final dimension to x + x = x.unsqueeze(-1) - if x.ndim == 1: - x = x.unsqueeze(1) # (batch_size, 1) - if x.ndim == 2: - x = x.unsqueeze(2) # (batch_size, in_dim, 1) + # Add an initial dimension to knots + knots = self.knots.unsqueeze(0) - if knots.ndim == 1: - knots = knots.unsqueeze(0) # (1, n_knots) - if knots.ndim == 2: - knots = knots.unsqueeze(0) # (1, in_dim, n_knots) - - # Base case: k=0 + # Base case of recursion: indicator functions for the intervals basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) basis = basis.to(x.dtype) + # One-dimensional knots case: ensure rightmost boundary inclusion if self._boundary_interval_idx is not None: - i = self._boundary_interval_idx - tolerance = 1e-10 - x_squeezed = x.squeeze(-1) - knot_left = knots[..., i] - knot_right = knots[..., i + 1] - at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance - in_rightmost_interval = ( - x_squeezed >= knot_left - ) & at_right_boundary + # Extract left and right knots of the rightmost interval + knot_left = knots[..., self._boundary_interval_idx] + knot_right = knots[..., self._boundary_interval_idx + 1] - if torch.any(in_rightmost_interval): - # For points at the boundary, ensure they're included in the - # rightmost interval - basis[..., i] = torch.logical_or( - basis[..., i].bool(), in_rightmost_interval + # Identify points at the rightmost boundary + at_rightmost_boundary = ( + x.squeeze(-1) >= knot_left + ) & torch.isclose(x.squeeze(-1), knot_right, rtol=1e-8, atol=1e-10) + + # Ensure the correct value is set at the rightmost boundary + if torch.any(at_rightmost_boundary): + basis[..., self._boundary_interval_idx] = torch.logical_or( + basis[..., self._boundary_interval_idx].bool(), + at_rightmost_boundary, ).to(basis.dtype) - # Iterative step (Cox-de Boor recursion) - for i in range(1, k + 1): - # First term of the recursion + # Iterative case of recursion + for i in range(1, self.order): + + # Compute the denominators for both terms denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] + denom2 = knots[..., i + 1 :] - knots[..., 1:-i] + + # Ensure no division by zero denom1 = torch.where( torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 ) - numer1 = x - knots[..., : -(i + 1)] - term1 = (numer1 / denom1) * basis[..., :-1] - - denom2 = knots[..., i + 1 :] - knots[..., 1:-i] denom2 = torch.where( torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 ) - numer2 = knots[..., i + 1 :] - x - term2 = (numer2 / denom2) * basis[..., 1:] + # Compute the two terms of the recursion + term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1] + term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:] + + # Combine terms to get the new basis basis = term1 + term2 return basis - def compute_control_points(self, x_eval, y_eval): - """ - Compute control points from given evaluations using least squares. - This method fits the control points to match the target y_eval values. - """ - # (batch, in_dim) - A = self.basis(x_eval, self.k, self.knots) - # (batch, in_dim, n_basis) - - in_dim = A.shape[1] - out_dim = y_eval.shape[2] - n_basis = A.shape[2] - c = torch.zeros(in_dim, out_dim, n_basis).to(A.device) - - for i in range(in_dim): - # A_i is (batch, n_basis) - # y_i is (batch, out_dim) - A_i = A[:, i, :] - y_i = y_eval[:, i, :] - c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim) - c[i, :, :] = c_i.T # (out_dim, n_basis) - - self.control_points = torch.nn.Parameter(c) - def forward(self, x): """ Forward pass for the :class:`Spline` model. - :param torch.Tensor x: The input tensor. + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor :return: The output tensor. :rtype: torch.Tensor """ - t = self.knots - k = self.k - c = self.control_points - - # Create the basis functions - # B will have shape (batch, in_dim, n_basis) - B = self.basis(x, k, t) - - # KAN case where control points are (in_dim, out_dim, n_basis) - if c.ndim == 3: - y_ij = torch.einsum( - "bil,iol->bio", B, c - ) # (batch, in_dim, out_dim) - # sum over input dimensions - y = torch.sum(y_ij, dim=1) # (batch, out_dim) - # Original test case - else: - B = B.squeeze(1) # (batch, n_basis) - if c.ndim == 1: - y = torch.einsum("bi,i->b", B, c) - else: - y = torch.einsum("bi,ij->bj", B, c) - - return y + return torch.einsum( + "bi, i -> b", + self.basis(x.as_subclass(torch.Tensor)).squeeze(1), + self.control_points, + ).reshape(-1, 1) @property def control_points(self): @@ -231,27 +245,42 @@ class Spline(torch.nn.Module): return self._control_points @control_points.setter - def control_points(self, value): + def control_points(self, control_points): """ Set the control points of the spline. - :param value: The control points. - :type value: torch.Tensor | dict - :raises ValueError: If invalid value is passed. + :param torch.Tensor control_points: The control points tensor. If None, + control points are initialized to learnable parameters with zero + initial value. Default is None. + :raises ValueError: If there are not enough knots to define the control + points, due to the relation: #knots = order + #control_points. + :raises ValueError: If control_points is not a torch.Tensor. """ - if isinstance(value, dict): - if "n" not in value: - raise ValueError("Invalid value for control_points") - n = value["n"] - dim = value.get("dim", 1) - value = torch.zeros(n, dim) + # If control points are not provided, initialize them + if control_points is None: - if not isinstance(value, torch.nn.Parameter): - value = torch.nn.Parameter(value) + # Check that there are enough knots to define control points + if len(self.knots) < self.order + 1: + raise ValueError( + f"Not enough knots to define control points. Got " + f"{len(self.knots)} knots, but need at least " + f"{self.order + 1}." + ) - if not isinstance(value, torch.Tensor): - raise ValueError("Invalid value for control_points") - self._control_points = value + # Initialize control points to zero + control_points = torch.zeros(len(self.knots) - self.order) + + # Check validity of control points + elif not isinstance(control_points, torch.Tensor): + raise ValueError( + "control_points must be a torch.Tensor," + f" got {type(control_points)}" + ) + + # Set control points + self._control_points = torch.nn.Parameter( + control_points, requires_grad=True + ) @property def knots(self): @@ -268,37 +297,80 @@ class Spline(torch.nn.Module): """ Set the knots of the spline. - :param value: The knots. + :param value: The knots of the spline. If a tensor is provided, knots + are set directly from the tensor. If a dictionary is provided, it + must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. + Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` + define the interval, and ``"mode"`` selects the sampling strategy. + The supported modes are ``"uniform"``, where the knots are evenly + spaced over :math:`[min, max]`, and ``"auto"``, where knots are + constructed to ensure that the spline interpolates the first and + last control points. In this case, the number of knots is inferred + and the ``"n"`` key is ignored. :type value: torch.Tensor | dict - :raises ValueError: If invalid value is passed. + :raises ValueError: If value is not a torch.Tensor or a dictionary. + :raises ValueError: If a dictionary is provided but does not contain + the required keys. + :raises ValueError: If the mode specified in the dictionary is invalid. """ + # Check validity of knots + if not isinstance(value, (torch.Tensor, dict)): + raise ValueError( + "Knots must be a torch.Tensor or a dictionary," + f" got {type(value)}." + ) + + # If a dictionary is provided, initialize knots accordingly if isinstance(value, dict): - type_ = value.get("type", "auto") - min_ = value.get("min", 0) - max_ = value.get("max", 1) - n = value.get("n", 10) + # Check that required keys are present + required_keys = {"n", "min", "max", "mode"} + if not required_keys.issubset(value.keys()): + raise ValueError( + f"When providing knots as a dictionary, the following " + f"keys must be present: {required_keys}. Got " + f"{value.keys()}." + ) - if type_ == "uniform": - value = torch.linspace(min_, max_, n + self.k + 1) - elif type_ == "auto": - initial_knots = torch.ones(self.order + 1) * min_ - final_knots = torch.ones(self.order + 1) * max_ + # Uniform sampling of knots + if value["mode"] == "uniform": + value = torch.linspace(value["min"], value["max"], value["n"]) - if n < self.order + 1: - value = torch.concatenate((initial_knots, final_knots)) - elif n - 2 * self.order + 1 == 1: - value = torch.Tensor([(max_ + min_) / 2]) + # Automatic sampling of interpolating knots + elif value["mode"] == "auto": + + # Repeat the first and last knots 'order' times + initial_knots = torch.ones(self.order) * value["min"] + final_knots = torch.ones(self.order) * value["max"] + + # Number of internal knots + n_internal = value["n"] - 2 * self.order + + # If no internal knots are needed, just concatenate boundaries + if n_internal <= 0: + value = torch.cat((initial_knots, final_knots)) + + # Else, sample internal knots uniformly and exclude boundaries + # Recover the correct number of internal knots when slicing by + # adding 2 to n_internal else: - value = torch.linspace(min_, max_, n - 2 * self.order - 1) + internal_knots = torch.linspace( + value["min"], value["max"], n_internal + 2 + )[1:-1] + value = torch.cat( + (initial_knots, internal_knots, final_knots) + ) - value = torch.concatenate((initial_knots, value, final_knots)) + # Raise error if mode is invalid + else: + raise ValueError( + f"Invalid mode for knots initialization. Got " + f"{value['mode']}, but expected 'uniform' or 'auto'." + ) - if not isinstance(value, torch.Tensor): - raise ValueError("Invalid value for knots") - - self._knots = value + # Set knots + self.register_buffer("_knots", value.sort(dim=0).values) # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): - self._compute_boundary_interval() \ No newline at end of file + self._boundary_interval_idx = self._compute_boundary_interval() diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index d38b161..c30e542 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -1,81 +1,171 @@ import torch import pytest - +import numpy as np +from scipy.interpolate import BSpline from pina.model import Spline - -data = torch.rand((20, 3)) -input_vars = 3 -output_vars = 4 - -valid_args = [ - { - "knots": torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0]), - "control_points": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]), - "order": 3, - }, - { - "knots": torch.tensor( - [-2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0] - ), - "control_points": torch.tensor([0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]), - "order": 4, - }, - # {'control_points': {'n': 5, 'dim': 1}, 'order': 2}, - # {'control_points': {'n': 7, 'dim': 1}, 'order': 3} -] +from pina import LabelTensor -def scipy_check(model, x, y): - from scipy.interpolate._bsplines import BSpline - import numpy as np +# Utility quantities for testing +order = torch.randint(1, 8, (1,)).item() +n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() +n_knots = order + n_ctrl_pts - spline = BSpline( +# Input tensor +pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"]) + + +# Function to compare with scipy implementation +def check_scipy_spline(model, x, output_): + + # Define scipy spline + scipy_spline = BSpline( t=model.knots.detach().numpy(), c=model.control_points.detach().numpy(), k=model.order - 1, ) - y_scipy = spline(x).flatten() - y = y.detach().numpy() - np.testing.assert_allclose(y, y_scipy, atol=1e-5) + + # Compare outputs + np.testing.assert_allclose( + output_.squeeze().detach().numpy(), + scipy_spline(x).flatten(), + atol=1e-5, + rtol=1e-5, + ) + + +# Define all possible combinations of valid arguments for the Spline class +valid_args = [ + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": torch.linspace(0, 1, n_knots), + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, + }, + { + "order": order, + "control_points": None, + "knots": torch.linspace(0, 1, n_knots), + }, + { + "order": order, + "control_points": None, + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, + }, + { + "order": order, + "control_points": None, + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": None, + }, +] @pytest.mark.parametrize("args", valid_args) def test_constructor(args): Spline(**args) + # Should fail if order is not a positive integer + with pytest.raises(AssertionError): + Spline( + order=-1, control_points=args["control_points"], knots=args["knots"] + ) -def test_constructor_wrong(): + # Should fail if control_points is not None or a torch.Tensor with pytest.raises(ValueError): - Spline() + Spline( + order=args["order"], control_points=[1, 2, 3], knots=args["knots"] + ) + + # Should fail if knots is not None, a torch.Tensor, or a dict + with pytest.raises(ValueError): + Spline( + order=args["order"], control_points=args["control_points"], knots=5 + ) + + # Should fail if both knots and control_points are None + with pytest.raises(ValueError): + Spline(order=args["order"], control_points=None, knots=None) + + # Should fail if knots is not one-dimensional + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots=torch.rand(n_knots, 4), + ) + + # Should fail if control_points is not one-dimensional + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=torch.rand(n_ctrl_pts, 4), + knots=args["knots"], + ) + + # Should fail if the number of knots != order + number of control points + # If control points are None, they are initialized to fulfill this condition + if args["control_points"] is not None: + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots=torch.linspace(0, 1, n_knots + 1), + ) + + # Should fail if the knot dict is missing required keys + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1}, + ) + + # Should fail if the knot dict has invalid 'mode' key + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"}, + ) @pytest.mark.parametrize("args", valid_args) def test_forward(args): - min_x = args["knots"][0] - max_x = args["knots"][-1] - xi = torch.linspace(min_x, max_x, 1000) + + # Define the model model = Spline(**args) - yi = model(xi).squeeze() - scipy_check(model, xi, yi) - return + + # Evaluate the model + output_ = model(pts) + assert output_.shape == (pts.shape[0], 1) + + # Compare with scipy implementation only for interpolant knots (mode: auto) + if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto": + check_scipy_spline(model, pts, output_) @pytest.mark.parametrize("args", valid_args) def test_backward(args): - min_x = args["knots"][0] - max_x = args["knots"][-1] - xi = torch.linspace(min_x, max_x, 100) - model = Spline(**args) - yi = model(xi) - fake_loss = torch.sum(yi) - assert model.control_points.grad is None - fake_loss.backward() - assert model.control_points.grad is not None - # dim_in, dim_out = 3, 2 - # fnn = FeedForward(dim_in, dim_out) - # data.requires_grad = True - # output_ = fnn(data) - # l=torch.mean(output_) - # l.backward() - # assert data._grad.shape == torch.Size([20,3]) + # Define the model + model = Spline(**args) + + # Evaluate the model + output_ = model(pts) + loss = torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape