From df4ea64c747be8a5d0dece65e1040209aacdb9f3 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Mon, 6 Oct 2025 15:50:14 +0200 Subject: [PATCH] add b-spline surface --- docs/source/_rst/_code.rst | 1 + docs/source/_rst/model/spline_surface.rst | 7 + pina/model/__init__.py | 1 + pina/model/spline.py | 32 ++-- pina/model/spline_surface.py | 212 ++++++++++++++++++++++ tests/test_model/test_spline.py | 22 ++- tests/test_model/test_spline_surface.py | 180 ++++++++++++++++++ 7 files changed, 425 insertions(+), 30 deletions(-) create mode 100644 docs/source/_rst/model/spline_surface.rst create mode 100644 pina/model/spline_surface.py create mode 100644 tests/test_model/test_spline_surface.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 965a286..1516994 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -95,6 +95,7 @@ Models MultiFeedForward ResidualFeedForward Spline + SplineSurface DeepONet MIONet KernelNeuralOperator diff --git a/docs/source/_rst/model/spline_surface.rst b/docs/source/_rst/model/spline_surface.rst new file mode 100644 index 0000000..6bbf137 --- /dev/null +++ b/docs/source/_rst/model/spline_surface.rst @@ -0,0 +1,7 @@ +Spline Surface +================ +.. currentmodule:: pina.model.spline_surface + +.. autoclass:: SplineSurface + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 1edeacd..05ccc6c 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -26,6 +26,7 @@ from .kernel_neural_operator import KernelNeuralOperator from .average_neural_operator import AveragingNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator from .spline import Spline +from .spline_surface import SplineSurface from .graph_neural_operator import GraphNeuralOperator from .pirate_network import PirateNet from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator diff --git a/pina/model/spline.py b/pina/model/spline.py index 6800384..a276a6c 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -1,8 +1,8 @@ """Module for the B-Spline model class.""" -import torch import warnings -from ..utils import check_positive_integer +import torch +from ..utils import check_positive_integer, check_consistency class Spline(torch.nn.Module): @@ -75,6 +75,10 @@ class Spline(torch.nn.Module): If None, they are initialized as learnable parameters with an initial value of zero. Default is None. :raises AssertionError: If ``order`` is not a positive integer. + :raises ValueError: If ``knots`` is neither a torch.Tensor nor a + dictionary, when provided. + :raises ValueError: If ``control_points`` is not a torch.Tensor, + when provided. :raises ValueError: If both ``knots`` and ``control_points`` are None. :raises ValueError: If ``knots`` is not one-dimensional. :raises ValueError: If ``control_points`` is not one-dimensional. @@ -87,6 +91,8 @@ class Spline(torch.nn.Module): # Check consistency check_positive_integer(value=order, strict=True) + check_consistency(knots, (type(None), torch.Tensor, dict)) + check_consistency(control_points, (type(None), torch.Tensor)) # Raise error if neither knots nor control points are provided if knots is None and control_points is None: @@ -229,10 +235,10 @@ class Spline(torch.nn.Module): :rtype: torch.Tensor """ return torch.einsum( - "bi, i -> b", - self.basis(x.as_subclass(torch.Tensor)).squeeze(1), + "...bi, i -> ...b", + self.basis(x.as_subclass(torch.Tensor)).squeeze(-1), self.control_points, - ).reshape(-1, 1) + ) @property def control_points(self): @@ -254,7 +260,6 @@ class Spline(torch.nn.Module): initial value. Default is None. :raises ValueError: If there are not enough knots to define the control points, due to the relation: #knots = order + #control_points. - :raises ValueError: If control_points is not a torch.Tensor. """ # If control points are not provided, initialize them if control_points is None: @@ -270,13 +275,6 @@ class Spline(torch.nn.Module): # Initialize control points to zero control_points = torch.zeros(len(self.knots) - self.order) - # Check validity of control points - elif not isinstance(control_points, torch.Tensor): - raise ValueError( - "control_points must be a torch.Tensor," - f" got {type(control_points)}" - ) - # Set control points self._control_points = torch.nn.Parameter( control_points, requires_grad=True @@ -308,18 +306,10 @@ class Spline(torch.nn.Module): last control points. In this case, the number of knots is inferred and the ``"n"`` key is ignored. :type value: torch.Tensor | dict - :raises ValueError: If value is not a torch.Tensor or a dictionary. :raises ValueError: If a dictionary is provided but does not contain the required keys. :raises ValueError: If the mode specified in the dictionary is invalid. """ - # Check validity of knots - if not isinstance(value, (torch.Tensor, dict)): - raise ValueError( - "Knots must be a torch.Tensor or a dictionary," - f" got {type(value)}." - ) - # If a dictionary is provided, initialize knots accordingly if isinstance(value, dict): diff --git a/pina/model/spline_surface.py b/pina/model/spline_surface.py new file mode 100644 index 0000000..30d41bb --- /dev/null +++ b/pina/model/spline_surface.py @@ -0,0 +1,212 @@ +"""Module for the bivariate B-Spline surface model class.""" + +import torch +from .spline import Spline +from ..utils import check_consistency + + +class SplineSurface(torch.nn.Module): + r""" + The bivariate B-Spline surface model class. + + A bivariate B-spline surface is a parametric surface defined as the tensor + product of two univariate B-spline curves: + + .. math:: + + S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, + \quad x \in [x_1, x_m], y \in [y_1, y_l] + + where: + + - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed + points influence the shape of the surface but are not generally + interpolated, except at the boundaries under certain knot multiplicities. + - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions + defined over two orthogonal directions, with orders :math:`k` and + :math:`s`, respectively. + - :math:`X = \{ x_1, x_2, \dots, x_m \}` and + :math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot + vectors along the two directions. + """ + + def __init__(self, orders, knots_u=None, knots_v=None, control_points=None): + """ + Initialization of the :class:`SplineSurface` class. + + :param list[int] orders: The orders of the spline along each parametric + direction. Each order defines the degree of the corresponding basis + as ``degree = order - 1``. + :param knots_u: The knots of the spline along the first direction. + For details on valid formats and initialization modes, see the + :class:`Spline` class. Default is None. + :type knots_u: torch.Tensor | dict + :param knots_v: The knots of the spline along the second direction. + For details on valid formats and initialization modes, see the + :class:`Spline` class. Default is None. + :type knots_v: torch.Tensor | dict + :param torch.Tensor control_points: The control points defining the + surface geometry. It must be a two-dimensional tensor of shape + ``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``. + If None, they are initialized as learnable parameters with zero + values. Default is None. + :raises ValueError: If ``orders`` is not a list of integers. + :raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a + dictionary, when provided. + :raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a + dictionary, when provided. + :raises ValueError: If ``control_points`` is not a torch.Tensor, + when provided. + :raises ValueError: If ``orders`` is not a list of two elements. + :raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points`` + are all None. + """ + super().__init__() + + # Check consistency + check_consistency(orders, int) + check_consistency(control_points, (type(None), torch.Tensor)) + check_consistency(knots_u, (type(None), torch.Tensor, dict)) + check_consistency(knots_v, (type(None), torch.Tensor, dict)) + + # Check orders is a list of two elements + if len(orders) != 2: + raise ValueError("orders must be a list of two elements.") + + # Raise error if neither knots nor control points are provided + if (knots_u is None or knots_v is None) and control_points is None: + raise ValueError( + "control_points cannot be None if knots_u or knots_v is None." + ) + + # Initialize knots_u if not provided + if knots_u is None and control_points is not None: + knots_u = { + "n": control_points.shape[0] + orders[0], + "min": 0, + "max": 1, + "mode": "auto", + } + + # Initialize knots_v if not provided + if knots_v is None and control_points is not None: + knots_v = { + "n": control_points.shape[1] + orders[1], + "min": 0, + "max": 1, + "mode": "auto", + } + + # Create two univariate b-splines + self.spline_u = Spline(order=orders[0], knots=knots_u) + self.spline_v = Spline(order=orders[1], knots=knots_v) + self.control_points = control_points + + # Delete unneeded parameters + delattr(self.spline_u, "_control_points") + delattr(self.spline_v, "_control_points") + + def forward(self, x): + """ + Forward pass for the :class:`SplineSurface` model. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The output tensor. + :rtype: torch.Tensor + """ + return torch.einsum( + "...bi, ...bj, ij -> ...b", + self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]), + self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]), + self.control_points, + ).unsqueeze(-1) + + @property + def knots(self): + """ + The knots of the univariate splines defining the spline surface. + + :return: The knots. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + return self.spline_u.knots, self.spline_v.knots + + @knots.setter + def knots(self, value): + """ + Set the knots of the spline surface. + + :param value: A tuple (knots_u, knots_v) containing the knots for both + parametric directions. + :type value: tuple(torch.Tensor | dict, torch.Tensor | dict) + :raises ValueError: If value is not a tuple of two elements. + """ + # Check value is a tuple of two elements + if not (isinstance(value, tuple) and len(value) == 2): + raise ValueError("Knots must be a tuple of two elements.") + + knots_u, knots_v = value + self.spline_u.knots = knots_u + self.spline_v.knots = knots_v + + @property + def control_points(self): + """ + The control points of the spline. + + :return: The control points. + :rtype: torch.Tensor + """ + return self._control_points + + @control_points.setter + def control_points(self, control_points): + """ + Set the control points of the spline surface. + + :param torch.Tensor control_points: The bidimensional control points + tensor, where each dimension refers to a direction in the parameter + space. If None, control points are initialized to learnable + parameters with zero initial value. Default is None. + :raises ValueError: If in any direction there are not enough knots to + define the control points, due to the relation: + #knots = order + #control_points. + :raises ValueError: If ``control_points`` is not of the correct shape. + """ + # Save correct shape of control points + __valid_shape = ( + len(self.spline_u.knots) - self.spline_u.order, + len(self.spline_v.knots) - self.spline_v.order, + ) + + # If control points are not provided, initialize them + if control_points is None: + + # Check that there are enough knots to define control points + if ( + len(self.spline_u.knots) < self.spline_u.order + 1 + or len(self.spline_v.knots) < self.spline_v.order + 1 + ): + raise ValueError( + f"Not enough knots to define control points. Got " + f"{len(self.spline_u.knots)} knots along u and " + f"{len(self.spline_v.knots)} knots along v, but need at " + f"least {self.spline_u.order + 1} and " + f"{self.spline_v.order + 1}, respectively." + ) + + # Initialize control points to zero + control_points = torch.zeros(__valid_shape) + + # Check control points + if control_points.shape != __valid_shape: + raise ValueError( + "control_points must be of the correct shape. ", + f"Expected {__valid_shape}, got {control_points.shape}.", + ) + + # Register control points as a learnable parameter + self._control_points = torch.nn.Parameter( + control_points, requires_grad=True + ) diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index c30e542..d22de9f 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -1,6 +1,5 @@ import torch import pytest -import numpy as np from scipy.interpolate import BSpline from pina.model import Spline from pina import LabelTensor @@ -12,7 +11,10 @@ n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() n_knots = order + n_ctrl_pts # Input tensor -pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"]) +points = [ + LabelTensor(torch.rand(100, 1), ["x"]), + LabelTensor(torch.rand(2, 100, 1), ["x"]), +] # Function to compare with scipy implementation @@ -26,15 +28,15 @@ def check_scipy_spline(model, x, output_): ) # Compare outputs - np.testing.assert_allclose( - output_.squeeze().detach().numpy(), - scipy_spline(x).flatten(), + torch.allclose( + output_, + torch.tensor(scipy_spline(x), dtype=output_.dtype), atol=1e-5, rtol=1e-5, ) -# Define all possible combinations of valid arguments for the Spline class +# Define all possible combinations of valid arguments for Spline class valid_args = [ { "order": order, @@ -144,14 +146,15 @@ def test_constructor(args): @pytest.mark.parametrize("args", valid_args) -def test_forward(args): +@pytest.mark.parametrize("pts", points) +def test_forward(args, pts): # Define the model model = Spline(**args) # Evaluate the model output_ = model(pts) - assert output_.shape == (pts.shape[0], 1) + assert output_.shape == pts.shape # Compare with scipy implementation only for interpolant knots (mode: auto) if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto": @@ -159,7 +162,8 @@ def test_forward(args): @pytest.mark.parametrize("args", valid_args) -def test_backward(args): +@pytest.mark.parametrize("pts", points) +def test_backward(args, pts): # Define the model model = Spline(**args) diff --git a/tests/test_model/test_spline_surface.py b/tests/test_model/test_spline_surface.py new file mode 100644 index 0000000..feab587 --- /dev/null +++ b/tests/test_model/test_spline_surface.py @@ -0,0 +1,180 @@ +import torch +import random +import pytest +from pina.model import SplineSurface +from pina import LabelTensor + + +# Utility quantities for testing +orders = [random.randint(1, 8) for _ in range(2)] +n_ctrl_pts = random.randint(max(orders), max(orders) + 5) +n_knots = [orders[i] + n_ctrl_pts for i in range(2)] + +# Input tensor +points = [ + LabelTensor(torch.rand(100, 2), ["x", "y"]), + LabelTensor(torch.rand(2, 100, 2), ["x", "y"]), +] + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + None, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + None, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +def test_constructor(knots_u, knots_v, control_points): + + # Skip if knots_u, knots_v, and control_points are all None + if (knots_u is None or knots_v is None) and control_points is None: + return + + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Should fail if orders is not list of two elements + with pytest.raises(ValueError): + SplineSurface( + orders=[orders[0]], + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Should fail if both knots and control_points are None + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=None, + knots_v=None, + control_points=None, + ) + + # Should fail if control_points is not a torch.Tensor when provided + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts, + ) + + # Should fail if control_points is not of the correct shape when provided + # It assumes that at least one among knots_u and knots_v is not None + if knots_u is not None or knots_v is not None: + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts + 1), + ) + + # Should fail if there are not enough knots_u to define the control points + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=torch.linspace(0, 1, orders[0]), + knots_v=knots_v, + control_points=None, + ) + + # Should fail if there are not enough knots_v to define the control points + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=torch.linspace(0, 1, orders[1]), + control_points=None, + ) + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +@pytest.mark.parametrize("pts", points) +def test_forward(knots_u, knots_v, control_points, pts): + + # Define the model + model = SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Evaluate the model + output_ = model(pts) + assert output_.shape == (*pts.shape[:-1], 1) + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +@pytest.mark.parametrize("pts", points) +def test_backward(knots_u, knots_v, control_points, pts): + + # Define the model + model = SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Evaluate the model + output_ = model(pts) + loss = torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape