import torch
import pytest
import numpy as np
from scipy.interpolate import BSpline
from pina.model import Spline
from pina import LabelTensor

# Spline sizes shared by every test in this module, drawn at random on import.
# A spline of order `order` with `n_ctrl_pts` control points uses exactly
# `order + n_ctrl_pts` knots.
order = torch.randint(1, 8, (1,)).item()
n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
n_knots = order + n_ctrl_pts

# 100 evenly spaced evaluation points on [0, 1], stored as a column labelled "x".
pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"])


def check_scipy_spline(model, x, output_):
    """Assert that ``output_`` agrees with scipy's B-spline built from ``model``.

    The reference spline reuses the model's knots and control points. scipy's
    ``k`` argument is the polynomial degree, hence ``model.order - 1``.

    Args:
        model: the ``Spline`` whose output is being checked.
        x: the points at which ``output_`` was computed.
        output_: the model's output at ``x``.
    """
    reference = BSpline(
        t=model.knots.detach().numpy(),
        c=model.control_points.detach().numpy(),
        k=model.order - 1,
    )
    actual = output_.squeeze().detach().numpy()
    expected = reference(x).flatten()
    np.testing.assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)


def _knot_variants():
    """Return fresh instances of the three accepted ``knots`` specifications.

    Each call builds new objects, so no two test cases share a tensor or dict.
    """
    return [
        torch.linspace(0, 1, n_knots),
        {"n": n_knots, "min": 0, "max": 1, "mode": "auto"},
        {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"},
    ]


# Every valid combination of constructor arguments for Spline:
# explicit control points with each knot specification, then no control
# points with each knot specification, then control points without knots.
valid_args = [
    {"order": order, "control_points": torch.rand(n_ctrl_pts), "knots": knots}
    for knots in _knot_variants()
]
valid_args += [
    {"order": order, "control_points": None, "knots": knots}
    for knots in _knot_variants()
]
valid_args.append(
    {"order": order, "control_points": torch.rand(n_ctrl_pts), "knots": None}
)


@pytest.mark.parametrize("args", valid_args)
def test_constructor(args):
    """Valid arguments build a Spline; each kind of invalid input raises."""
    Spline(**args)

    def expect_failure(error, **overrides):
        # Start from the valid arguments and replace only the ones under test.
        kwargs = {
            "order": args["order"],
            "control_points": args["control_points"],
            "knots": args["knots"],
            **overrides,
        }
        with pytest.raises(error):
            Spline(**kwargs)

    # The order must be a positive integer.
    expect_failure(AssertionError, order=-1)
    # Control points must be None or a torch.Tensor.
    expect_failure(ValueError, control_points=[1, 2, 3])
    # Knots must be None, a torch.Tensor, or a dict.
    expect_failure(ValueError, knots=5)
    # Knots and control points cannot both be None.
    expect_failure(ValueError, control_points=None, knots=None)
    # Knots must be one-dimensional.
    expect_failure(ValueError, knots=torch.rand(n_knots, 4))
    # Control points must be one-dimensional.
    expect_failure(ValueError, control_points=torch.rand(n_ctrl_pts, 4))
    # The number of knots must equal order + number of control points.
    # When control points are None they are initialised to satisfy this,
    # so the check only applies when they are given explicitly.
    if args["control_points"] is not None:
        expect_failure(ValueError, knots=torch.linspace(0, 1, n_knots + 1))
    # A knot dict must contain every required key.
    expect_failure(ValueError, knots={"n": n_knots, "min": 0, "max": 1})
    # A knot dict must use a recognised 'mode'.
    expect_failure(
        ValueError,
        knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"},
    )


@pytest.mark.parametrize("args", valid_args)
def test_forward(args):
    """The forward pass returns one value per input point."""
    model = Spline(**args)
    output_ = model(pts)
    assert output_.shape == (pts.shape[0], 1)

    # Only interpolant knots (mode "auto") are compared against scipy.
    knots = args["knots"]
    is_auto_mode = isinstance(knots, dict) and knots["mode"] == "auto"
    if is_auto_mode:
        check_scipy_spline(model, pts, output_)


@pytest.mark.parametrize("args", valid_args)
def test_backward(args):
    """Gradients reach the control points with a matching shape."""
    model = Spline(**args)
    torch.mean(model(pts)).backward()
    assert model.control_points.grad.shape == model.control_points.shape