add b-spline surface

This commit is contained in:
GiovanniCanali
2025-10-06 15:50:14 +02:00
parent 71ce8c55f6
commit df4ea64c74
7 changed files with 425 additions and 30 deletions

View File

@@ -1,8 +1,8 @@
"""Module for the B-Spline model class."""
import torch
import warnings
from ..utils import check_positive_integer
import torch
from ..utils import check_positive_integer, check_consistency
class Spline(torch.nn.Module):
@@ -75,6 +75,10 @@ class Spline(torch.nn.Module):
If None, they are initialized as learnable parameters with an
initial value of zero. Default is None.
:raises AssertionError: If ``order`` is not a positive integer.
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
dictionary, when provided.
:raises ValueError: If ``control_points`` is not a torch.Tensor,
when provided.
:raises ValueError: If both ``knots`` and ``control_points`` are None.
:raises ValueError: If ``knots`` is not one-dimensional.
:raises ValueError: If ``control_points`` is not one-dimensional.
@@ -87,6 +91,8 @@ class Spline(torch.nn.Module):
# Check consistency
check_positive_integer(value=order, strict=True)
check_consistency(knots, (type(None), torch.Tensor, dict))
check_consistency(control_points, (type(None), torch.Tensor))
# Raise error if neither knots nor control points are provided
if knots is None and control_points is None:
@@ -229,10 +235,10 @@ class Spline(torch.nn.Module):
:rtype: torch.Tensor
"""
return torch.einsum(
"bi, i -> b",
self.basis(x.as_subclass(torch.Tensor)).squeeze(1),
"...bi, i -> ...b",
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
self.control_points,
).reshape(-1, 1)
)
@property
def control_points(self):
@@ -254,7 +260,6 @@ class Spline(torch.nn.Module):
initial value. Default is None.
:raises ValueError: If there are not enough knots to define the control
points, due to the relation: #knots = order + #control_points.
:raises ValueError: If control_points is not a torch.Tensor.
"""
# If control points are not provided, initialize them
if control_points is None:
@@ -270,13 +275,6 @@ class Spline(torch.nn.Module):
# Initialize control points to zero
control_points = torch.zeros(len(self.knots) - self.order)
# Check validity of control points
elif not isinstance(control_points, torch.Tensor):
raise ValueError(
"control_points must be a torch.Tensor,"
f" got {type(control_points)}"
)
# Set control points
self._control_points = torch.nn.Parameter(
control_points, requires_grad=True
@@ -308,18 +306,10 @@ class Spline(torch.nn.Module):
last control points. In this case, the number of knots is inferred
and the ``"n"`` key is ignored.
:type value: torch.Tensor | dict
:raises ValueError: If value is not a torch.Tensor or a dictionary.
:raises ValueError: If a dictionary is provided but does not contain
the required keys.
:raises ValueError: If the mode specified in the dictionary is invalid.
"""
# Check validity of knots
if not isinstance(value, (torch.Tensor, dict)):
raise ValueError(
"Knots must be a torch.Tensor or a dictionary,"
f" got {type(value)}."
)
# If a dictionary is provided, initialize knots accordingly
if isinstance(value, dict):