add b-spline surface
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""Module for the B-Spline model class."""
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from ..utils import check_positive_integer
|
||||
import torch
|
||||
from ..utils import check_positive_integer, check_consistency
|
||||
|
||||
|
||||
class Spline(torch.nn.Module):
|
||||
@@ -75,6 +75,10 @@ class Spline(torch.nn.Module):
|
||||
If None, they are initialized as learnable parameters with an
|
||||
initial value of zero. Default is None.
|
||||
:raises AssertionError: If ``order`` is not a positive integer.
|
||||
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
|
||||
dictionary, when provided.
|
||||
:raises ValueError: If ``control_points`` is not a torch.Tensor,
|
||||
when provided.
|
||||
:raises ValueError: If both ``knots`` and ``control_points`` are None.
|
||||
:raises ValueError: If ``knots`` is not one-dimensional.
|
||||
:raises ValueError: If ``control_points`` is not one-dimensional.
|
||||
@@ -87,6 +91,8 @@ class Spline(torch.nn.Module):
|
||||
|
||||
# Check consistency
|
||||
check_positive_integer(value=order, strict=True)
|
||||
check_consistency(knots, (type(None), torch.Tensor, dict))
|
||||
check_consistency(control_points, (type(None), torch.Tensor))
|
||||
|
||||
# Raise error if neither knots nor control points are provided
|
||||
if knots is None and control_points is None:
|
||||
@@ -229,10 +235,10 @@ class Spline(torch.nn.Module):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return torch.einsum(
|
||||
"bi, i -> b",
|
||||
self.basis(x.as_subclass(torch.Tensor)).squeeze(1),
|
||||
"...bi, i -> ...b",
|
||||
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
|
||||
self.control_points,
|
||||
).reshape(-1, 1)
|
||||
)
|
||||
|
||||
@property
|
||||
def control_points(self):
|
||||
@@ -254,7 +260,6 @@ class Spline(torch.nn.Module):
|
||||
initial value. Default is None.
|
||||
:raises ValueError: If there are not enough knots to define the control
|
||||
points, due to the relation: #knots = order + #control_points.
|
||||
:raises ValueError: If control_points is not a torch.Tensor.
|
||||
"""
|
||||
# If control points are not provided, initialize them
|
||||
if control_points is None:
|
||||
@@ -270,13 +275,6 @@ class Spline(torch.nn.Module):
|
||||
# Initialize control points to zero
|
||||
control_points = torch.zeros(len(self.knots) - self.order)
|
||||
|
||||
# Check validity of control points
|
||||
elif not isinstance(control_points, torch.Tensor):
|
||||
raise ValueError(
|
||||
"control_points must be a torch.Tensor,"
|
||||
f" got {type(control_points)}"
|
||||
)
|
||||
|
||||
# Set control points
|
||||
self._control_points = torch.nn.Parameter(
|
||||
control_points, requires_grad=True
|
||||
@@ -308,18 +306,10 @@ class Spline(torch.nn.Module):
|
||||
last control points. In this case, the number of knots is inferred
|
||||
and the ``"n"`` key is ignored.
|
||||
:type value: torch.Tensor | dict
|
||||
:raises ValueError: If value is not a torch.Tensor or a dictionary.
|
||||
:raises ValueError: If a dictionary is provided but does not contain
|
||||
the required keys.
|
||||
:raises ValueError: If the mode specified in the dictionary is invalid.
|
||||
"""
|
||||
# Check validity of knots
|
||||
if not isinstance(value, (torch.Tensor, dict)):
|
||||
raise ValueError(
|
||||
"Knots must be a torch.Tensor or a dictionary,"
|
||||
f" got {type(value)}."
|
||||
)
|
||||
|
||||
# If a dictionary is provided, initialize knots accordingly
|
||||
if isinstance(value, dict):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user