add b-spline surface

This commit is contained in:
GiovanniCanali
2025-10-06 15:50:14 +02:00
parent 71ce8c55f6
commit df4ea64c74
7 changed files with 425 additions and 30 deletions

View File

@@ -95,6 +95,7 @@ Models
MultiFeedForward <model/multi_feed_forward.rst> MultiFeedForward <model/multi_feed_forward.rst>
ResidualFeedForward <model/residual_feed_forward.rst> ResidualFeedForward <model/residual_feed_forward.rst>
Spline <model/spline.rst> Spline <model/spline.rst>
SplineSurface <model/spline_surface.rst>
DeepONet <model/deeponet.rst> DeepONet <model/deeponet.rst>
MIONet <model/mionet.rst> MIONet <model/mionet.rst>
KernelNeuralOperator <model/kernel_neural_operator.rst> KernelNeuralOperator <model/kernel_neural_operator.rst>

View File

@@ -0,0 +1,7 @@
Spline Surface
================
.. currentmodule:: pina.model.spline_surface
.. autoclass:: SplineSurface
:members:
:show-inheritance:

View File

@@ -26,6 +26,7 @@ from .kernel_neural_operator import KernelNeuralOperator
from .average_neural_operator import AveragingNeuralOperator from .average_neural_operator import AveragingNeuralOperator
from .low_rank_neural_operator import LowRankNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator
from .spline import Spline from .spline import Spline
from .spline_surface import SplineSurface
from .graph_neural_operator import GraphNeuralOperator from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet from .pirate_network import PirateNet
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator

View File

@@ -1,8 +1,8 @@
"""Module for the B-Spline model class.""" """Module for the B-Spline model class."""
import torch
import warnings import warnings
from ..utils import check_positive_integer import torch
from ..utils import check_positive_integer, check_consistency
class Spline(torch.nn.Module): class Spline(torch.nn.Module):
@@ -75,6 +75,10 @@ class Spline(torch.nn.Module):
If None, they are initialized as learnable parameters with an If None, they are initialized as learnable parameters with an
initial value of zero. Default is None. initial value of zero. Default is None.
:raises AssertionError: If ``order`` is not a positive integer. :raises AssertionError: If ``order`` is not a positive integer.
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
dictionary, when provided.
:raises ValueError: If ``control_points`` is not a torch.Tensor,
when provided.
:raises ValueError: If both ``knots`` and ``control_points`` are None. :raises ValueError: If both ``knots`` and ``control_points`` are None.
:raises ValueError: If ``knots`` is not one-dimensional. :raises ValueError: If ``knots`` is not one-dimensional.
:raises ValueError: If ``control_points`` is not one-dimensional. :raises ValueError: If ``control_points`` is not one-dimensional.
@@ -87,6 +91,8 @@ class Spline(torch.nn.Module):
# Check consistency # Check consistency
check_positive_integer(value=order, strict=True) check_positive_integer(value=order, strict=True)
check_consistency(knots, (type(None), torch.Tensor, dict))
check_consistency(control_points, (type(None), torch.Tensor))
# Raise error if neither knots nor control points are provided # Raise error if neither knots nor control points are provided
if knots is None and control_points is None: if knots is None and control_points is None:
@@ -229,10 +235,10 @@ class Spline(torch.nn.Module):
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
return torch.einsum( return torch.einsum(
"bi, i -> b", "...bi, i -> ...b",
self.basis(x.as_subclass(torch.Tensor)).squeeze(1), self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
self.control_points, self.control_points,
).reshape(-1, 1) )
@property @property
def control_points(self): def control_points(self):
@@ -254,7 +260,6 @@ class Spline(torch.nn.Module):
initial value. Default is None. initial value. Default is None.
:raises ValueError: If there are not enough knots to define the control :raises ValueError: If there are not enough knots to define the control
points, due to the relation: #knots = order + #control_points. points, due to the relation: #knots = order + #control_points.
:raises ValueError: If control_points is not a torch.Tensor.
""" """
# If control points are not provided, initialize them # If control points are not provided, initialize them
if control_points is None: if control_points is None:
@@ -270,13 +275,6 @@ class Spline(torch.nn.Module):
# Initialize control points to zero # Initialize control points to zero
control_points = torch.zeros(len(self.knots) - self.order) control_points = torch.zeros(len(self.knots) - self.order)
# Check validity of control points
elif not isinstance(control_points, torch.Tensor):
raise ValueError(
"control_points must be a torch.Tensor,"
f" got {type(control_points)}"
)
# Set control points # Set control points
self._control_points = torch.nn.Parameter( self._control_points = torch.nn.Parameter(
control_points, requires_grad=True control_points, requires_grad=True
@@ -308,18 +306,10 @@ class Spline(torch.nn.Module):
last control points. In this case, the number of knots is inferred last control points. In this case, the number of knots is inferred
and the ``"n"`` key is ignored. and the ``"n"`` key is ignored.
:type value: torch.Tensor | dict :type value: torch.Tensor | dict
:raises ValueError: If value is not a torch.Tensor or a dictionary.
:raises ValueError: If a dictionary is provided but does not contain :raises ValueError: If a dictionary is provided but does not contain
the required keys. the required keys.
:raises ValueError: If the mode specified in the dictionary is invalid. :raises ValueError: If the mode specified in the dictionary is invalid.
""" """
# Check validity of knots
if not isinstance(value, (torch.Tensor, dict)):
raise ValueError(
"Knots must be a torch.Tensor or a dictionary,"
f" got {type(value)}."
)
# If a dictionary is provided, initialize knots accordingly # If a dictionary is provided, initialize knots accordingly
if isinstance(value, dict): if isinstance(value, dict):

View File

@@ -0,0 +1,212 @@
"""Module for the bivariate B-Spline surface model class."""
import torch
from .spline import Spline
from ..utils import check_consistency
class SplineSurface(torch.nn.Module):
r"""
The bivariate B-Spline surface model class.
A bivariate B-spline surface is a parametric surface defined as the tensor
product of two univariate B-spline curves:
.. math::
S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j},
\quad x \in [x_1, x_m], y \in [y_1, y_l]
where:
- :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed
points influence the shape of the surface but are not generally
interpolated, except at the boundaries under certain knot multiplicities.
- :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions
defined over two orthogonal directions, with orders :math:`k` and
:math:`s`, respectively.
- :math:`X = \{ x_1, x_2, \dots, x_m \}` and
:math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot
vectors along the two directions.
"""
def __init__(self, orders, knots_u=None, knots_v=None, control_points=None):
"""
Initialization of the :class:`SplineSurface` class.
:param list[int] orders: The orders of the spline along each parametric
direction. Each order defines the degree of the corresponding basis
as ``degree = order - 1``.
:param knots_u: The knots of the spline along the first direction.
For details on valid formats and initialization modes, see the
:class:`Spline` class. Default is None.
:type knots_u: torch.Tensor | dict
:param knots_v: The knots of the spline along the second direction.
For details on valid formats and initialization modes, see the
:class:`Spline` class. Default is None.
:type knots_v: torch.Tensor | dict
:param torch.Tensor control_points: The control points defining the
surface geometry. It must be a two-dimensional tensor of shape
``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``.
If None, they are initialized as learnable parameters with zero
values. Default is None.
:raises ValueError: If ``orders`` is not a list of integers.
:raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a
dictionary, when provided.
:raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a
dictionary, when provided.
:raises ValueError: If ``control_points`` is not a torch.Tensor,
when provided.
:raises ValueError: If ``orders`` is not a list of two elements.
:raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points``
are all None.
"""
super().__init__()
# Check consistency
check_consistency(orders, int)
check_consistency(control_points, (type(None), torch.Tensor))
check_consistency(knots_u, (type(None), torch.Tensor, dict))
check_consistency(knots_v, (type(None), torch.Tensor, dict))
# Check orders is a list of two elements
if len(orders) != 2:
raise ValueError("orders must be a list of two elements.")
# Raise error if neither knots nor control points are provided
if (knots_u is None or knots_v is None) and control_points is None:
raise ValueError(
"control_points cannot be None if knots_u or knots_v is None."
)
# Initialize knots_u if not provided
if knots_u is None and control_points is not None:
knots_u = {
"n": control_points.shape[0] + orders[0],
"min": 0,
"max": 1,
"mode": "auto",
}
# Initialize knots_v if not provided
if knots_v is None and control_points is not None:
knots_v = {
"n": control_points.shape[1] + orders[1],
"min": 0,
"max": 1,
"mode": "auto",
}
# Create two univariate b-splines
self.spline_u = Spline(order=orders[0], knots=knots_u)
self.spline_v = Spline(order=orders[1], knots=knots_v)
self.control_points = control_points
# Delete unneeded parameters
delattr(self.spline_u, "_control_points")
delattr(self.spline_v, "_control_points")
def forward(self, x):
"""
Forward pass for the :class:`SplineSurface` model.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
return torch.einsum(
"...bi, ...bj, ij -> ...b",
self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]),
self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]),
self.control_points,
).unsqueeze(-1)
@property
def knots(self):
"""
The knots of the univariate splines defining the spline surface.
:return: The knots.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
return self.spline_u.knots, self.spline_v.knots
@knots.setter
def knots(self, value):
"""
Set the knots of the spline surface.
:param value: A tuple (knots_u, knots_v) containing the knots for both
parametric directions.
:type value: tuple(torch.Tensor | dict, torch.Tensor | dict)
:raises ValueError: If value is not a tuple of two elements.
"""
# Check value is a tuple of two elements
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Knots must be a tuple of two elements.")
knots_u, knots_v = value
self.spline_u.knots = knots_u
self.spline_v.knots = knots_v
@property
def control_points(self):
"""
The control points of the spline.
:return: The control points.
:rtype: torch.Tensor
"""
return self._control_points
@control_points.setter
def control_points(self, control_points):
"""
Set the control points of the spline surface.
:param torch.Tensor control_points: The bidimensional control points
tensor, where each dimension refers to a direction in the parameter
space. If None, control points are initialized to learnable
parameters with zero initial value. Default is None.
:raises ValueError: If in any direction there are not enough knots to
define the control points, due to the relation:
#knots = order + #control_points.
:raises ValueError: If ``control_points`` is not of the correct shape.
"""
# Save correct shape of control points
__valid_shape = (
len(self.spline_u.knots) - self.spline_u.order,
len(self.spline_v.knots) - self.spline_v.order,
)
# If control points are not provided, initialize them
if control_points is None:
# Check that there are enough knots to define control points
if (
len(self.spline_u.knots) < self.spline_u.order + 1
or len(self.spline_v.knots) < self.spline_v.order + 1
):
raise ValueError(
f"Not enough knots to define control points. Got "
f"{len(self.spline_u.knots)} knots along u and "
f"{len(self.spline_v.knots)} knots along v, but need at "
f"least {self.spline_u.order + 1} and "
f"{self.spline_v.order + 1}, respectively."
)
# Initialize control points to zero
control_points = torch.zeros(__valid_shape)
# Check control points
if control_points.shape != __valid_shape:
raise ValueError(
"control_points must be of the correct shape. ",
f"Expected {__valid_shape}, got {control_points.shape}.",
)
# Register control points as a learnable parameter
self._control_points = torch.nn.Parameter(
control_points, requires_grad=True
)

View File

@@ -1,6 +1,5 @@
import torch import torch
import pytest import pytest
import numpy as np
from scipy.interpolate import BSpline from scipy.interpolate import BSpline
from pina.model import Spline from pina.model import Spline
from pina import LabelTensor from pina import LabelTensor
@@ -12,7 +11,10 @@ n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
n_knots = order + n_ctrl_pts n_knots = order + n_ctrl_pts
# Input tensor # Input tensor
pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"]) points = [
LabelTensor(torch.rand(100, 1), ["x"]),
LabelTensor(torch.rand(2, 100, 1), ["x"]),
]
# Function to compare with scipy implementation # Function to compare with scipy implementation
@@ -26,15 +28,15 @@ def check_scipy_spline(model, x, output_):
) )
# Compare outputs # Compare outputs
np.testing.assert_allclose( torch.allclose(
output_.squeeze().detach().numpy(), output_,
scipy_spline(x).flatten(), torch.tensor(scipy_spline(x), dtype=output_.dtype),
atol=1e-5, atol=1e-5,
rtol=1e-5, rtol=1e-5,
) )
# Define all possible combinations of valid arguments for the Spline class # Define all possible combinations of valid arguments for Spline class
valid_args = [ valid_args = [
{ {
"order": order, "order": order,
@@ -144,14 +146,15 @@ def test_constructor(args):
@pytest.mark.parametrize("args", valid_args) @pytest.mark.parametrize("args", valid_args)
def test_forward(args): @pytest.mark.parametrize("pts", points)
def test_forward(args, pts):
# Define the model # Define the model
model = Spline(**args) model = Spline(**args)
# Evaluate the model # Evaluate the model
output_ = model(pts) output_ = model(pts)
assert output_.shape == (pts.shape[0], 1) assert output_.shape == pts.shape
# Compare with scipy implementation only for interpolant knots (mode: auto) # Compare with scipy implementation only for interpolant knots (mode: auto)
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto": if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
@@ -159,7 +162,8 @@ def test_forward(args):
@pytest.mark.parametrize("args", valid_args) @pytest.mark.parametrize("args", valid_args)
def test_backward(args): @pytest.mark.parametrize("pts", points)
def test_backward(args, pts):
# Define the model # Define the model
model = Spline(**args) model = Spline(**args)

View File

@@ -0,0 +1,180 @@
import torch
import random
import pytest
from pina.model import SplineSurface
from pina import LabelTensor
# Utility quantities for testing
orders = [random.randint(1, 8) for _ in range(2)]
n_ctrl_pts = random.randint(max(orders), max(orders) + 5)
n_knots = [orders[i] + n_ctrl_pts for i in range(2)]
# Input tensor
points = [
LabelTensor(torch.rand(100, 2), ["x", "y"]),
LabelTensor(torch.rand(2, 100, 2), ["x", "y"]),
]
@pytest.mark.parametrize(
"knots_u",
[
torch.rand(n_knots[0]),
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
None,
],
)
@pytest.mark.parametrize(
"knots_v",
[
torch.rand(n_knots[1]),
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
None,
],
)
@pytest.mark.parametrize(
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
)
def test_constructor(knots_u, knots_v, control_points):
# Skip if knots_u, knots_v, and control_points are all None
if (knots_u is None or knots_v is None) and control_points is None:
return
SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=control_points,
)
# Should fail if orders is not list of two elements
with pytest.raises(ValueError):
SplineSurface(
orders=[orders[0]],
knots_u=knots_u,
knots_v=knots_v,
control_points=control_points,
)
# Should fail if both knots and control_points are None
with pytest.raises(ValueError):
SplineSurface(
orders=orders,
knots_u=None,
knots_v=None,
control_points=None,
)
# Should fail if control_points is not a torch.Tensor when provided
with pytest.raises(ValueError):
SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts,
)
# Should fail if control_points is not of the correct shape when provided
# It assumes that at least one among knots_u and knots_v is not None
if knots_u is not None or knots_v is not None:
with pytest.raises(ValueError):
SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts + 1),
)
# Should fail if there are not enough knots_u to define the control points
with pytest.raises(ValueError):
SplineSurface(
orders=orders,
knots_u=torch.linspace(0, 1, orders[0]),
knots_v=knots_v,
control_points=None,
)
# Should fail if there are not enough knots_v to define the control points
with pytest.raises(ValueError):
SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=torch.linspace(0, 1, orders[1]),
control_points=None,
)
@pytest.mark.parametrize(
"knots_u",
[
torch.rand(n_knots[0]),
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"knots_v",
[
torch.rand(n_knots[1]),
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
)
@pytest.mark.parametrize("pts", points)
def test_forward(knots_u, knots_v, control_points, pts):
# Define the model
model = SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=control_points,
)
# Evaluate the model
output_ = model(pts)
assert output_.shape == (*pts.shape[:-1], 1)
@pytest.mark.parametrize(
"knots_u",
[
torch.rand(n_knots[0]),
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"knots_v",
[
torch.rand(n_knots[1]),
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
)
@pytest.mark.parametrize("pts", points)
def test_backward(knots_u, knots_v, control_points, pts):
# Define the model
model = SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=control_points,
)
# Evaluate the model
output_ = model(pts)
loss = torch.mean(output_)
loss.backward()
assert model.control_points.grad.shape == model.control_points.shape