@@ -95,7 +95,6 @@ Models
|
|||||||
MultiFeedForward <model/multi_feed_forward.rst>
|
MultiFeedForward <model/multi_feed_forward.rst>
|
||||||
ResidualFeedForward <model/residual_feed_forward.rst>
|
ResidualFeedForward <model/residual_feed_forward.rst>
|
||||||
Spline <model/spline.rst>
|
Spline <model/spline.rst>
|
||||||
SplineSurface <model/spline_surface.rst>
|
|
||||||
DeepONet <model/deeponet.rst>
|
DeepONet <model/deeponet.rst>
|
||||||
MIONet <model/mionet.rst>
|
MIONet <model/mionet.rst>
|
||||||
KernelNeuralOperator <model/kernel_neural_operator.rst>
|
KernelNeuralOperator <model/kernel_neural_operator.rst>
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
Spline Surface
|
|
||||||
================
|
|
||||||
.. currentmodule:: pina.model.spline_surface
|
|
||||||
|
|
||||||
.. autoclass:: SplineSurface
|
|
||||||
:members:
|
|
||||||
:show-inheritance:
|
|
||||||
@@ -26,7 +26,6 @@ from .kernel_neural_operator import KernelNeuralOperator
|
|||||||
from .average_neural_operator import AveragingNeuralOperator
|
from .average_neural_operator import AveragingNeuralOperator
|
||||||
from .low_rank_neural_operator import LowRankNeuralOperator
|
from .low_rank_neural_operator import LowRankNeuralOperator
|
||||||
from .spline import Spline
|
from .spline import Spline
|
||||||
from .spline_surface import SplineSurface
|
|
||||||
from .graph_neural_operator import GraphNeuralOperator
|
from .graph_neural_operator import GraphNeuralOperator
|
||||||
from .pirate_network import PirateNet
|
from .pirate_network import PirateNet
|
||||||
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
|
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
|
||||||
|
|||||||
@@ -1,244 +1,109 @@
|
|||||||
"""Module for the B-Spline model class."""
|
"""Module for the Spline model class."""
|
||||||
|
|
||||||
import warnings
|
|
||||||
import torch
|
import torch
|
||||||
from ..utils import check_positive_integer, check_consistency
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
class Spline(torch.nn.Module):
|
class Spline(torch.nn.Module):
|
||||||
r"""
|
"""
|
||||||
The univariate B-Spline curve model class.
|
Spline model class.
|
||||||
|
|
||||||
A univariate B-spline curve of order :math:`k` is a parametric curve defined
|
|
||||||
as a linear combination of B-spline basis functions and control points:
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
|
|
||||||
S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m]
|
|
||||||
|
|
||||||
where:
|
|
||||||
|
|
||||||
- :math:`C_i \in \mathbb{R}` are the control points. These fixed points
|
|
||||||
influence the shape of the curve but are not generally interpolated,
|
|
||||||
except at the boundaries under certain knot multiplicities.
|
|
||||||
- :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`,
|
|
||||||
i.e., piecewise polynomials of degree :math:`k-1` with support on the
|
|
||||||
interval :math:`[x_i, x_{i+k}]`.
|
|
||||||
- :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector.
|
|
||||||
|
|
||||||
If the first and last knots are repeated :math:`k` times, then the curve
|
|
||||||
interpolates the first and last control points.
|
|
||||||
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
The curve is forced to be zero outside the interval defined by the
|
|
||||||
first and last knots.
|
|
||||||
|
|
||||||
|
|
||||||
:Example:
|
|
||||||
|
|
||||||
>>> from pina.model import Spline
|
|
||||||
>>> import torch
|
|
||||||
|
|
||||||
>>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
|
|
||||||
>>> spline1 = Spline(order=3, knots=knots1, control_points=None)
|
|
||||||
|
|
||||||
>>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"}
|
|
||||||
>>> spline2 = Spline(order=3, knots=knots2, control_points=None)
|
|
||||||
|
|
||||||
>>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
|
|
||||||
>>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0])
|
|
||||||
>>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, order=4, knots=None, control_points=None):
|
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
||||||
"""
|
"""
|
||||||
Initialization of the :class:`Spline` class.
|
Initialization of the :class:`Spline` class.
|
||||||
|
|
||||||
:param int order: The order of the spline. The corresponding basis
|
:param int order: The order of the spline. Default is ``4``.
|
||||||
functions are polynomials of degree ``order - 1``. Default is 4.
|
:param torch.Tensor knots: The tensor representing knots. If ``None``,
|
||||||
:param knots: The knots of the spline. If a tensor is provided, knots
|
the knots will be initialized automatically. Default is ``None``.
|
||||||
are set directly from the tensor. If a dictionary is provided, it
|
:param torch.Tensor control_points: The control points. Default is
|
||||||
must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``.
|
``None``.
|
||||||
Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"``
|
:raises ValueError: If the order is negative.
|
||||||
define the interval, and ``"mode"`` selects the sampling strategy.
|
:raises ValueError: If both knots and control points are ``None``.
|
||||||
The supported modes are ``"uniform"``, where the knots are evenly
|
:raises ValueError: If the knot tensor is not one-dimensional.
|
||||||
spaced over :math:`[min, max]`, and ``"auto"``, where knots are
|
|
||||||
constructed to ensure that the spline interpolates the first and
|
|
||||||
last control points. In this case, the number of knots is adjusted
|
|
||||||
if :math:`n < 2 * order`. If None is given, knots are initialized
|
|
||||||
automatically over :math:`[0, 1]` ensuring interpolation of the
|
|
||||||
first and last control points. Default is None.
|
|
||||||
:type knots: torch.Tensor | dict
|
|
||||||
:param torch.Tensor control_points: The control points of the spline.
|
|
||||||
If None, they are initialized as learnable parameters with an
|
|
||||||
initial value of zero. Default is None.
|
|
||||||
:raises AssertionError: If ``order`` is not a positive integer.
|
|
||||||
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
|
|
||||||
dictionary, when provided.
|
|
||||||
:raises ValueError: If ``control_points`` is not a torch.Tensor,
|
|
||||||
when provided.
|
|
||||||
:raises ValueError: If both ``knots`` and ``control_points`` are None.
|
|
||||||
:raises ValueError: If ``knots`` is not one-dimensional.
|
|
||||||
:raises ValueError: If ``control_points`` is not one-dimensional.
|
|
||||||
:raises ValueError: If the number of ``knots`` is not equal to the sum
|
|
||||||
of ``order`` and the number of ``control_points.``
|
|
||||||
:raises UserWarning: If the number of control points is lower than the
|
|
||||||
order, resulting in a degenerate spline.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Check consistency
|
check_consistency(order, int)
|
||||||
check_positive_integer(value=order, strict=True)
|
|
||||||
check_consistency(knots, (type(None), torch.Tensor, dict))
|
|
||||||
check_consistency(control_points, (type(None), torch.Tensor))
|
|
||||||
|
|
||||||
# Raise error if neither knots nor control points are provided
|
if order < 0:
|
||||||
|
raise ValueError("Spline order cannot be negative.")
|
||||||
if knots is None and control_points is None:
|
if knots is None and control_points is None:
|
||||||
raise ValueError("knots and control_points cannot both be None.")
|
raise ValueError("Knots and control points cannot be both None.")
|
||||||
|
|
||||||
# Initialize knots if not provided
|
self.order = order
|
||||||
if knots is None and control_points is not None:
|
self.k = order - 1
|
||||||
knots = {
|
|
||||||
"n": len(control_points) + order,
|
if knots is not None and control_points is not None:
|
||||||
|
self.knots = knots
|
||||||
|
self.control_points = control_points
|
||||||
|
|
||||||
|
elif knots is not None:
|
||||||
|
print("Warning: control points will be initialized automatically.")
|
||||||
|
print(" experimental feature")
|
||||||
|
|
||||||
|
self.knots = knots
|
||||||
|
n = len(knots) - order
|
||||||
|
self.control_points = torch.nn.Parameter(
|
||||||
|
torch.zeros(n), requires_grad=True
|
||||||
|
)
|
||||||
|
|
||||||
|
elif control_points is not None:
|
||||||
|
print("Warning: knots will be initialized automatically.")
|
||||||
|
print(" experimental feature")
|
||||||
|
|
||||||
|
self.control_points = control_points
|
||||||
|
|
||||||
|
n = len(self.control_points) - 1
|
||||||
|
self.knots = {
|
||||||
|
"type": "auto",
|
||||||
"min": 0,
|
"min": 0,
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"mode": "auto",
|
"n": n + 2 + self.order,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialization - knots and control points managed by their setters
|
else:
|
||||||
self.order = order
|
raise ValueError("Knots and control points cannot be both None.")
|
||||||
self.knots = knots
|
|
||||||
self.control_points = control_points
|
|
||||||
|
|
||||||
# Check dimensionality of knots
|
if self.knots.ndim != 1:
|
||||||
if self.knots.ndim > 1:
|
raise ValueError("Knot vector must be one-dimensional.")
|
||||||
raise ValueError("knots must be one-dimensional.")
|
|
||||||
|
|
||||||
# Check dimensionality of control points
|
def basis(self, x, k, i, t):
|
||||||
if self.control_points.ndim > 1:
|
|
||||||
raise ValueError("control_points must be one-dimensional.")
|
|
||||||
|
|
||||||
# Raise error if #knots != order + #control_points
|
|
||||||
if len(self.knots) != self.order + len(self.control_points):
|
|
||||||
raise ValueError(
|
|
||||||
f" The number of knots must be equal to order + number of"
|
|
||||||
f" control points. Got {len(self.knots)} knots, {self.order}"
|
|
||||||
f" order and {len(self.control_points)} control points."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Raise warning if spline is degenerate
|
|
||||||
if len(self.control_points) < self.order:
|
|
||||||
warnings.warn(
|
|
||||||
"The number of control points is smaller than the spline order."
|
|
||||||
" This creates a degenerate spline with limited flexibility.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Precompute boundary interval index
|
|
||||||
self._boundary_interval_idx = self._compute_boundary_interval()
|
|
||||||
|
|
||||||
def _compute_boundary_interval(self):
|
|
||||||
"""
|
"""
|
||||||
Precompute the index of the rightmost non-degenerate interval to improve
|
Recursive method to compute the basis functions of the spline.
|
||||||
performance, eliminating the need to perform a search loop in the basis
|
|
||||||
function on each call.
|
|
||||||
|
|
||||||
:return: The index of the rightmost non-degenerate interval.
|
|
||||||
:rtype: int
|
|
||||||
"""
|
|
||||||
# Return 0 if there is a single interval
|
|
||||||
if len(self.knots) < 2:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Find all indices where knots are strictly increasing
|
|
||||||
diffs = self.knots[1:] - self.knots[:-1]
|
|
||||||
valid = torch.nonzero(diffs > 0, as_tuple=False)
|
|
||||||
|
|
||||||
# If all knots are equal, return 0 for degenerate spline
|
|
||||||
if valid.numel() == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Otherwise, return the last valid index
|
|
||||||
return int(valid[-1])
|
|
||||||
|
|
||||||
def basis(self, x):
|
|
||||||
"""
|
|
||||||
Compute the basis functions for the spline using an iterative approach.
|
|
||||||
This is a vectorized implementation based on the Cox-de Boor recursion.
|
|
||||||
|
|
||||||
:param torch.Tensor x: The points to be evaluated.
|
:param torch.Tensor x: The points to be evaluated.
|
||||||
:return: The basis functions evaluated at x.
|
:param int k: The spline degree.
|
||||||
|
:param int i: The index of the interval.
|
||||||
|
:param torch.Tensor t: The tensor of knots.
|
||||||
|
:return: The basis functions evaluated at x
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
# Add a final dimension to x
|
|
||||||
x = x.unsqueeze(-1)
|
|
||||||
|
|
||||||
# Add an initial dimension to knots
|
if k == 0:
|
||||||
knots = self.knots.unsqueeze(0)
|
a = torch.where(
|
||||||
|
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
|
||||||
# Base case of recursion: indicator functions for the intervals
|
|
||||||
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
|
|
||||||
basis = basis.to(x.dtype)
|
|
||||||
|
|
||||||
# One-dimensional knots case: ensure rightmost boundary inclusion
|
|
||||||
if self._boundary_interval_idx is not None:
|
|
||||||
|
|
||||||
# Extract left and right knots of the rightmost interval
|
|
||||||
knot_left = knots[..., self._boundary_interval_idx]
|
|
||||||
knot_right = knots[..., self._boundary_interval_idx + 1]
|
|
||||||
|
|
||||||
# Identify points at the rightmost boundary
|
|
||||||
at_rightmost_boundary = (
|
|
||||||
x.squeeze(-1) >= knot_left
|
|
||||||
) & torch.isclose(x.squeeze(-1), knot_right, rtol=1e-8, atol=1e-10)
|
|
||||||
|
|
||||||
# Ensure the correct value is set at the rightmost boundary
|
|
||||||
if torch.any(at_rightmost_boundary):
|
|
||||||
basis[..., self._boundary_interval_idx] = torch.logical_or(
|
|
||||||
basis[..., self._boundary_interval_idx].bool(),
|
|
||||||
at_rightmost_boundary,
|
|
||||||
).to(basis.dtype)
|
|
||||||
|
|
||||||
# Iterative case of recursion
|
|
||||||
for i in range(1, self.order):
|
|
||||||
|
|
||||||
# Compute the denominators for both terms
|
|
||||||
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
|
|
||||||
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
|
|
||||||
|
|
||||||
# Ensure no division by zero
|
|
||||||
denom1 = torch.where(
|
|
||||||
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1
|
|
||||||
)
|
)
|
||||||
denom2 = torch.where(
|
if i == len(t) - self.order - 1:
|
||||||
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2
|
a = torch.where(x == t[-1], 1.0, a)
|
||||||
|
a.requires_grad_(True)
|
||||||
|
return a
|
||||||
|
|
||||||
|
if t[i + k] == t[i]:
|
||||||
|
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
|
||||||
|
else:
|
||||||
|
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
|
||||||
|
|
||||||
|
if t[i + k + 1] == t[i + 1]:
|
||||||
|
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
|
||||||
|
else:
|
||||||
|
c2 = (
|
||||||
|
(t[i + k + 1] - x)
|
||||||
|
/ (t[i + k + 1] - t[i + 1])
|
||||||
|
* self.basis(x, k - 1, i + 1, t)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute the two terms of the recursion
|
return c1 + c2
|
||||||
term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1]
|
|
||||||
term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:]
|
|
||||||
|
|
||||||
# Combine terms to get the new basis
|
|
||||||
basis = term1 + term2
|
|
||||||
|
|
||||||
return basis
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
Forward pass for the :class:`Spline` model.
|
|
||||||
|
|
||||||
:param x: The input tensor.
|
|
||||||
:type x: torch.Tensor | LabelTensor
|
|
||||||
:return: The output tensor.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
return torch.einsum(
|
|
||||||
"...bi, i -> ...b",
|
|
||||||
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
|
|
||||||
self.control_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def control_points(self):
|
def control_points(self):
|
||||||
@@ -251,34 +116,24 @@ class Spline(torch.nn.Module):
|
|||||||
return self._control_points
|
return self._control_points
|
||||||
|
|
||||||
@control_points.setter
|
@control_points.setter
|
||||||
def control_points(self, control_points):
|
def control_points(self, value):
|
||||||
"""
|
"""
|
||||||
Set the control points of the spline.
|
Set the control points of the spline.
|
||||||
|
|
||||||
:param torch.Tensor control_points: The control points tensor. If None,
|
:param value: The control points.
|
||||||
control points are initialized to learnable parameters with zero
|
:type value: torch.Tensor | dict
|
||||||
initial value. Default is None.
|
:raises ValueError: If invalid value is passed.
|
||||||
:raises ValueError: If there are not enough knots to define the control
|
|
||||||
points, due to the relation: #knots = order + #control_points.
|
|
||||||
"""
|
"""
|
||||||
# If control points are not provided, initialize them
|
if isinstance(value, dict):
|
||||||
if control_points is None:
|
if "n" not in value:
|
||||||
|
raise ValueError("Invalid value for control_points")
|
||||||
|
n = value["n"]
|
||||||
|
dim = value.get("dim", 1)
|
||||||
|
value = torch.zeros(n, dim)
|
||||||
|
|
||||||
# Check that there are enough knots to define control points
|
if not isinstance(value, torch.Tensor):
|
||||||
if len(self.knots) < self.order + 1:
|
raise ValueError("Invalid value for control_points")
|
||||||
raise ValueError(
|
self._control_points = torch.nn.Parameter(value, requires_grad=True)
|
||||||
f"Not enough knots to define control points. Got "
|
|
||||||
f"{len(self.knots)} knots, but need at least "
|
|
||||||
f"{self.order + 1}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize control points to zero
|
|
||||||
control_points = torch.zeros(len(self.knots) - self.order)
|
|
||||||
|
|
||||||
# Set control points
|
|
||||||
self._control_points = torch.nn.Parameter(
|
|
||||||
control_points, requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def knots(self):
|
def knots(self):
|
||||||
@@ -295,72 +150,50 @@ class Spline(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Set the knots of the spline.
|
Set the knots of the spline.
|
||||||
|
|
||||||
:param value: The knots of the spline. If a tensor is provided, knots
|
:param value: The knots.
|
||||||
are set directly from the tensor. If a dictionary is provided, it
|
|
||||||
must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``.
|
|
||||||
Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"``
|
|
||||||
define the interval, and ``"mode"`` selects the sampling strategy.
|
|
||||||
The supported modes are ``"uniform"``, where the knots are evenly
|
|
||||||
spaced over :math:`[min, max]`, and ``"auto"``, where knots are
|
|
||||||
constructed to ensure that the spline interpolates the first and
|
|
||||||
last control points. In this case, the number of knots is inferred
|
|
||||||
and the ``"n"`` key is ignored.
|
|
||||||
:type value: torch.Tensor | dict
|
:type value: torch.Tensor | dict
|
||||||
:raises ValueError: If a dictionary is provided but does not contain
|
:raises ValueError: If invalid value is passed.
|
||||||
the required keys.
|
|
||||||
:raises ValueError: If the mode specified in the dictionary is invalid.
|
|
||||||
"""
|
"""
|
||||||
# If a dictionary is provided, initialize knots accordingly
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
|
|
||||||
# Check that required keys are present
|
type_ = value.get("type", "auto")
|
||||||
required_keys = {"n", "min", "max", "mode"}
|
min_ = value.get("min", 0)
|
||||||
if not required_keys.issubset(value.keys()):
|
max_ = value.get("max", 1)
|
||||||
raise ValueError(
|
n = value.get("n", 10)
|
||||||
f"When providing knots as a dictionary, the following "
|
|
||||||
f"keys must be present: {required_keys}. Got "
|
|
||||||
f"{value.keys()}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Uniform sampling of knots
|
if type_ == "uniform":
|
||||||
if value["mode"] == "uniform":
|
value = torch.linspace(min_, max_, n + self.k + 1)
|
||||||
value = torch.linspace(value["min"], value["max"], value["n"])
|
elif type_ == "auto":
|
||||||
|
initial_knots = torch.ones(self.order + 1) * min_
|
||||||
|
final_knots = torch.ones(self.order + 1) * max_
|
||||||
|
|
||||||
# Automatic sampling of interpolating knots
|
if n < self.order + 1:
|
||||||
elif value["mode"] == "auto":
|
value = torch.concatenate((initial_knots, final_knots))
|
||||||
|
elif n - 2 * self.order + 1 == 1:
|
||||||
# Repeat the first and last knots 'order' times
|
value = torch.Tensor([(max_ + min_) / 2])
|
||||||
initial_knots = torch.ones(self.order) * value["min"]
|
|
||||||
final_knots = torch.ones(self.order) * value["max"]
|
|
||||||
|
|
||||||
# Number of internal knots
|
|
||||||
n_internal = value["n"] - 2 * self.order
|
|
||||||
|
|
||||||
# If no internal knots are needed, just concatenate boundaries
|
|
||||||
if n_internal <= 0:
|
|
||||||
value = torch.cat((initial_knots, final_knots))
|
|
||||||
|
|
||||||
# Else, sample internal knots uniformly and exclude boundaries
|
|
||||||
# Recover the correct number of internal knots when slicing by
|
|
||||||
# adding 2 to n_internal
|
|
||||||
else:
|
else:
|
||||||
internal_knots = torch.linspace(
|
value = torch.linspace(min_, max_, n - 2 * self.order - 1)
|
||||||
value["min"], value["max"], n_internal + 2
|
|
||||||
)[1:-1]
|
|
||||||
value = torch.cat(
|
|
||||||
(initial_knots, internal_knots, final_knots)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Raise error if mode is invalid
|
value = torch.concatenate((initial_knots, value, final_knots))
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid mode for knots initialization. Got "
|
|
||||||
f"{value['mode']}, but expected 'uniform' or 'auto'."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set knots
|
if not isinstance(value, torch.Tensor):
|
||||||
self.register_buffer("_knots", value.sort(dim=0).values)
|
raise ValueError("Invalid value for knots")
|
||||||
|
|
||||||
# Recompute boundary interval when knots change
|
self._knots = value
|
||||||
if hasattr(self, "_boundary_interval_idx"):
|
|
||||||
self._boundary_interval_idx = self._compute_boundary_interval()
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass for the :class:`Spline` model.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The input tensor.
|
||||||
|
:return: The output tensor.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
t = self.knots
|
||||||
|
k = self.k
|
||||||
|
c = self.control_points
|
||||||
|
|
||||||
|
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
|
||||||
|
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
"""Module for the bivariate B-Spline surface model class."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from .spline import Spline
|
|
||||||
from ..utils import check_consistency
|
|
||||||
|
|
||||||
|
|
||||||
class SplineSurface(torch.nn.Module):
|
|
||||||
r"""
|
|
||||||
The bivariate B-Spline surface model class.
|
|
||||||
|
|
||||||
A bivariate B-spline surface is a parametric surface defined as the tensor
|
|
||||||
product of two univariate B-spline curves:
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
|
|
||||||
S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j},
|
|
||||||
\quad x \in [x_1, x_m], y \in [y_1, y_l]
|
|
||||||
|
|
||||||
where:
|
|
||||||
|
|
||||||
- :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed
|
|
||||||
points influence the shape of the surface but are not generally
|
|
||||||
interpolated, except at the boundaries under certain knot multiplicities.
|
|
||||||
- :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions
|
|
||||||
defined over two orthogonal directions, with orders :math:`k` and
|
|
||||||
:math:`s`, respectively.
|
|
||||||
- :math:`X = \{ x_1, x_2, \dots, x_m \}` and
|
|
||||||
:math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot
|
|
||||||
vectors along the two directions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, orders, knots_u=None, knots_v=None, control_points=None):
|
|
||||||
"""
|
|
||||||
Initialization of the :class:`SplineSurface` class.
|
|
||||||
|
|
||||||
:param list[int] orders: The orders of the spline along each parametric
|
|
||||||
direction. Each order defines the degree of the corresponding basis
|
|
||||||
as ``degree = order - 1``.
|
|
||||||
:param knots_u: The knots of the spline along the first direction.
|
|
||||||
For details on valid formats and initialization modes, see the
|
|
||||||
:class:`Spline` class. Default is None.
|
|
||||||
:type knots_u: torch.Tensor | dict
|
|
||||||
:param knots_v: The knots of the spline along the second direction.
|
|
||||||
For details on valid formats and initialization modes, see the
|
|
||||||
:class:`Spline` class. Default is None.
|
|
||||||
:type knots_v: torch.Tensor | dict
|
|
||||||
:param torch.Tensor control_points: The control points defining the
|
|
||||||
surface geometry. It must be a two-dimensional tensor of shape
|
|
||||||
``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``.
|
|
||||||
If None, they are initialized as learnable parameters with zero
|
|
||||||
values. Default is None.
|
|
||||||
:raises ValueError: If ``orders`` is not a list of integers.
|
|
||||||
:raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a
|
|
||||||
dictionary, when provided.
|
|
||||||
:raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a
|
|
||||||
dictionary, when provided.
|
|
||||||
:raises ValueError: If ``control_points`` is not a torch.Tensor,
|
|
||||||
when provided.
|
|
||||||
:raises ValueError: If ``orders`` is not a list of two elements.
|
|
||||||
:raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points``
|
|
||||||
are all None.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Check consistency
|
|
||||||
check_consistency(orders, int)
|
|
||||||
check_consistency(control_points, (type(None), torch.Tensor))
|
|
||||||
check_consistency(knots_u, (type(None), torch.Tensor, dict))
|
|
||||||
check_consistency(knots_v, (type(None), torch.Tensor, dict))
|
|
||||||
|
|
||||||
# Check orders is a list of two elements
|
|
||||||
if len(orders) != 2:
|
|
||||||
raise ValueError("orders must be a list of two elements.")
|
|
||||||
|
|
||||||
# Raise error if neither knots nor control points are provided
|
|
||||||
if (knots_u is None or knots_v is None) and control_points is None:
|
|
||||||
raise ValueError(
|
|
||||||
"control_points cannot be None if knots_u or knots_v is None."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize knots_u if not provided
|
|
||||||
if knots_u is None and control_points is not None:
|
|
||||||
knots_u = {
|
|
||||||
"n": control_points.shape[0] + orders[0],
|
|
||||||
"min": 0,
|
|
||||||
"max": 1,
|
|
||||||
"mode": "auto",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize knots_v if not provided
|
|
||||||
if knots_v is None and control_points is not None:
|
|
||||||
knots_v = {
|
|
||||||
"n": control_points.shape[1] + orders[1],
|
|
||||||
"min": 0,
|
|
||||||
"max": 1,
|
|
||||||
"mode": "auto",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create two univariate b-splines
|
|
||||||
self.spline_u = Spline(order=orders[0], knots=knots_u)
|
|
||||||
self.spline_v = Spline(order=orders[1], knots=knots_v)
|
|
||||||
self.control_points = control_points
|
|
||||||
|
|
||||||
# Delete unneeded parameters
|
|
||||||
delattr(self.spline_u, "_control_points")
|
|
||||||
delattr(self.spline_v, "_control_points")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
Forward pass for the :class:`SplineSurface` model.
|
|
||||||
|
|
||||||
:param x: The input tensor.
|
|
||||||
:type x: torch.Tensor | LabelTensor
|
|
||||||
:return: The output tensor.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
return torch.einsum(
|
|
||||||
"...bi, ...bj, ij -> ...b",
|
|
||||||
self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]),
|
|
||||||
self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]),
|
|
||||||
self.control_points,
|
|
||||||
).unsqueeze(-1)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def knots(self):
|
|
||||||
"""
|
|
||||||
The knots of the univariate splines defining the spline surface.
|
|
||||||
|
|
||||||
:return: The knots.
|
|
||||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
|
||||||
"""
|
|
||||||
return self.spline_u.knots, self.spline_v.knots
|
|
||||||
|
|
||||||
@knots.setter
|
|
||||||
def knots(self, value):
|
|
||||||
"""
|
|
||||||
Set the knots of the spline surface.
|
|
||||||
|
|
||||||
:param value: A tuple (knots_u, knots_v) containing the knots for both
|
|
||||||
parametric directions.
|
|
||||||
:type value: tuple(torch.Tensor | dict, torch.Tensor | dict)
|
|
||||||
:raises ValueError: If value is not a tuple of two elements.
|
|
||||||
"""
|
|
||||||
# Check value is a tuple of two elements
|
|
||||||
if not (isinstance(value, tuple) and len(value) == 2):
|
|
||||||
raise ValueError("Knots must be a tuple of two elements.")
|
|
||||||
|
|
||||||
knots_u, knots_v = value
|
|
||||||
self.spline_u.knots = knots_u
|
|
||||||
self.spline_v.knots = knots_v
|
|
||||||
|
|
||||||
@property
|
|
||||||
def control_points(self):
|
|
||||||
"""
|
|
||||||
The control points of the spline.
|
|
||||||
|
|
||||||
:return: The control points.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
return self._control_points
|
|
||||||
|
|
||||||
@control_points.setter
|
|
||||||
def control_points(self, control_points):
|
|
||||||
"""
|
|
||||||
Set the control points of the spline surface.
|
|
||||||
|
|
||||||
:param torch.Tensor control_points: The bidimensional control points
|
|
||||||
tensor, where each dimension refers to a direction in the parameter
|
|
||||||
space. If None, control points are initialized to learnable
|
|
||||||
parameters with zero initial value. Default is None.
|
|
||||||
:raises ValueError: If in any direction there are not enough knots to
|
|
||||||
define the control points, due to the relation:
|
|
||||||
#knots = order + #control_points.
|
|
||||||
:raises ValueError: If ``control_points`` is not of the correct shape.
|
|
||||||
"""
|
|
||||||
# Save correct shape of control points
|
|
||||||
__valid_shape = (
|
|
||||||
len(self.spline_u.knots) - self.spline_u.order,
|
|
||||||
len(self.spline_v.knots) - self.spline_v.order,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If control points are not provided, initialize them
|
|
||||||
if control_points is None:
|
|
||||||
|
|
||||||
# Check that there are enough knots to define control points
|
|
||||||
if (
|
|
||||||
len(self.spline_u.knots) < self.spline_u.order + 1
|
|
||||||
or len(self.spline_v.knots) < self.spline_v.order + 1
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough knots to define control points. Got "
|
|
||||||
f"{len(self.spline_u.knots)} knots along u and "
|
|
||||||
f"{len(self.spline_v.knots)} knots along v, but need at "
|
|
||||||
f"least {self.spline_u.order + 1} and "
|
|
||||||
f"{self.spline_v.order + 1}, respectively."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize control points to zero
|
|
||||||
control_points = torch.zeros(__valid_shape)
|
|
||||||
|
|
||||||
# Check control points
|
|
||||||
if control_points.shape != __valid_shape:
|
|
||||||
raise ValueError(
|
|
||||||
"control_points must be of the correct shape. ",
|
|
||||||
f"Expected {__valid_shape}, got {control_points.shape}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register control points as a learnable parameter
|
|
||||||
self._control_points = torch.nn.Parameter(
|
|
||||||
control_points, requires_grad=True
|
|
||||||
)
|
|
||||||
@@ -1,175 +1,81 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
from scipy.interpolate import BSpline
|
|
||||||
from pina.model import Spline
|
from pina.model import Spline
|
||||||
from pina import LabelTensor
|
|
||||||
|
|
||||||
|
data = torch.rand((20, 3))
|
||||||
|
input_vars = 3
|
||||||
|
output_vars = 4
|
||||||
|
|
||||||
# Utility quantities for testing
|
valid_args = [
|
||||||
order = torch.randint(1, 8, (1,)).item()
|
{
|
||||||
n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
|
"knots": torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0]),
|
||||||
n_knots = order + n_ctrl_pts
|
"control_points": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]),
|
||||||
|
"order": 3,
|
||||||
# Input tensor
|
},
|
||||||
points = [
|
{
|
||||||
LabelTensor(torch.rand(100, 1), ["x"]),
|
"knots": torch.tensor(
|
||||||
LabelTensor(torch.rand(2, 100, 1), ["x"]),
|
[-2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0]
|
||||||
|
),
|
||||||
|
"control_points": torch.tensor([0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]),
|
||||||
|
"order": 4,
|
||||||
|
},
|
||||||
|
# {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
|
||||||
|
# {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Function to compare with scipy implementation
|
def scipy_check(model, x, y):
|
||||||
def check_scipy_spline(model, x, output_):
|
from scipy.interpolate._bsplines import BSpline
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Define scipy spline
|
spline = BSpline(
|
||||||
scipy_spline = BSpline(
|
|
||||||
t=model.knots.detach().numpy(),
|
t=model.knots.detach().numpy(),
|
||||||
c=model.control_points.detach().numpy(),
|
c=model.control_points.detach().numpy(),
|
||||||
k=model.order - 1,
|
k=model.order - 1,
|
||||||
)
|
)
|
||||||
|
y_scipy = spline(x).flatten()
|
||||||
# Compare outputs
|
y = y.detach().numpy()
|
||||||
torch.allclose(
|
np.testing.assert_allclose(y, y_scipy, atol=1e-5)
|
||||||
output_,
|
|
||||||
torch.tensor(scipy_spline(x), dtype=output_.dtype),
|
|
||||||
atol=1e-5,
|
|
||||||
rtol=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Define all possible combinations of valid arguments for Spline class
|
|
||||||
valid_args = [
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": torch.rand(n_ctrl_pts),
|
|
||||||
"knots": torch.linspace(0, 1, n_knots),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": torch.rand(n_ctrl_pts),
|
|
||||||
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": torch.rand(n_ctrl_pts),
|
|
||||||
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": None,
|
|
||||||
"knots": torch.linspace(0, 1, n_knots),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": None,
|
|
||||||
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": None,
|
|
||||||
"knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"order": order,
|
|
||||||
"control_points": torch.rand(n_ctrl_pts),
|
|
||||||
"knots": None,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", valid_args)
|
@pytest.mark.parametrize("args", valid_args)
|
||||||
def test_constructor(args):
|
def test_constructor(args):
|
||||||
Spline(**args)
|
Spline(**args)
|
||||||
|
|
||||||
# Should fail if order is not a positive integer
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
Spline(
|
|
||||||
order=-1, control_points=args["control_points"], knots=args["knots"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if control_points is not None or a torch.Tensor
|
def test_constructor_wrong():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Spline(
|
Spline()
|
||||||
order=args["order"], control_points=[1, 2, 3], knots=args["knots"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if knots is not None, a torch.Tensor, or a dict
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"], control_points=args["control_points"], knots=5
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if both knots and control_points are None
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(order=args["order"], control_points=None, knots=None)
|
|
||||||
|
|
||||||
# Should fail if knots is not one-dimensional
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"],
|
|
||||||
control_points=args["control_points"],
|
|
||||||
knots=torch.rand(n_knots, 4),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if control_points is not one-dimensional
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"],
|
|
||||||
control_points=torch.rand(n_ctrl_pts, 4),
|
|
||||||
knots=args["knots"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if the number of knots != order + number of control points
|
|
||||||
# If control points are None, they are initialized to fulfill this condition
|
|
||||||
if args["control_points"] is not None:
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"],
|
|
||||||
control_points=args["control_points"],
|
|
||||||
knots=torch.linspace(0, 1, n_knots + 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if the knot dict is missing required keys
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"],
|
|
||||||
control_points=args["control_points"],
|
|
||||||
knots={"n": n_knots, "min": 0, "max": 1},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if the knot dict has invalid 'mode' key
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
Spline(
|
|
||||||
order=args["order"],
|
|
||||||
control_points=args["control_points"],
|
|
||||||
knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", valid_args)
|
@pytest.mark.parametrize("args", valid_args)
|
||||||
@pytest.mark.parametrize("pts", points)
|
def test_forward(args):
|
||||||
def test_forward(args, pts):
|
min_x = args["knots"][0]
|
||||||
|
max_x = args["knots"][-1]
|
||||||
# Define the model
|
xi = torch.linspace(min_x, max_x, 1000)
|
||||||
model = Spline(**args)
|
model = Spline(**args)
|
||||||
|
yi = model(xi).squeeze()
|
||||||
# Evaluate the model
|
scipy_check(model, xi, yi)
|
||||||
output_ = model(pts)
|
return
|
||||||
assert output_.shape == pts.shape
|
|
||||||
|
|
||||||
# Compare with scipy implementation only for interpolant knots (mode: auto)
|
|
||||||
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
|
|
||||||
check_scipy_spline(model, pts, output_)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", valid_args)
|
@pytest.mark.parametrize("args", valid_args)
|
||||||
@pytest.mark.parametrize("pts", points)
|
def test_backward(args):
|
||||||
def test_backward(args, pts):
|
min_x = args["knots"][0]
|
||||||
|
max_x = args["knots"][-1]
|
||||||
# Define the model
|
xi = torch.linspace(min_x, max_x, 100)
|
||||||
model = Spline(**args)
|
model = Spline(**args)
|
||||||
|
yi = model(xi)
|
||||||
|
fake_loss = torch.sum(yi)
|
||||||
|
assert model.control_points.grad is None
|
||||||
|
fake_loss.backward()
|
||||||
|
assert model.control_points.grad is not None
|
||||||
|
|
||||||
# Evaluate the model
|
# dim_in, dim_out = 3, 2
|
||||||
output_ = model(pts)
|
# fnn = FeedForward(dim_in, dim_out)
|
||||||
loss = torch.mean(output_)
|
# data.requires_grad = True
|
||||||
loss.backward()
|
# output_ = fnn(data)
|
||||||
assert model.control_points.grad.shape == model.control_points.shape
|
# l=torch.mean(output_)
|
||||||
|
# l.backward()
|
||||||
|
# assert data._grad.shape == torch.Size([20,3])
|
||||||
|
|||||||
@@ -1,180 +0,0 @@
|
|||||||
import torch
|
|
||||||
import random
|
|
||||||
import pytest
|
|
||||||
from pina.model import SplineSurface
|
|
||||||
from pina import LabelTensor
|
|
||||||
|
|
||||||
|
|
||||||
# Utility quantities for testing
|
|
||||||
orders = [random.randint(1, 8) for _ in range(2)]
|
|
||||||
n_ctrl_pts = random.randint(max(orders), max(orders) + 5)
|
|
||||||
n_knots = [orders[i] + n_ctrl_pts for i in range(2)]
|
|
||||||
|
|
||||||
# Input tensor
|
|
||||||
points = [
|
|
||||||
LabelTensor(torch.rand(100, 2), ["x", "y"]),
|
|
||||||
LabelTensor(torch.rand(2, 100, 2), ["x", "y"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_u",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[0]),
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
None,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_v",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[1]),
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
None,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
|
||||||
)
|
|
||||||
def test_constructor(knots_u, knots_v, control_points):
|
|
||||||
|
|
||||||
# Skip if knots_u, knots_v, and control_points are all None
|
|
||||||
if (knots_u is None or knots_v is None) and control_points is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=control_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if orders is not list of two elements
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=[orders[0]],
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=control_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if both knots and control_points are None
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=None,
|
|
||||||
knots_v=None,
|
|
||||||
control_points=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if control_points is not a torch.Tensor when provided
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if control_points is not of the correct shape when provided
|
|
||||||
# It assumes that at least one among knots_u and knots_v is not None
|
|
||||||
if knots_u is not None or knots_v is not None:
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts + 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if there are not enough knots_u to define the control points
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=torch.linspace(0, 1, orders[0]),
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should fail if there are not enough knots_v to define the control points
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=torch.linspace(0, 1, orders[1]),
|
|
||||||
control_points=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_u",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[0]),
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_v",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[1]),
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("pts", points)
|
|
||||||
def test_forward(knots_u, knots_v, control_points, pts):
|
|
||||||
|
|
||||||
# Define the model
|
|
||||||
model = SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=control_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Evaluate the model
|
|
||||||
output_ = model(pts)
|
|
||||||
assert output_.shape == (*pts.shape[:-1], 1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_u",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[0]),
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"knots_v",
|
|
||||||
[
|
|
||||||
torch.rand(n_knots[1]),
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
|
||||||
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("pts", points)
|
|
||||||
def test_backward(knots_u, knots_v, control_points, pts):
|
|
||||||
|
|
||||||
# Define the model
|
|
||||||
model = SplineSurface(
|
|
||||||
orders=orders,
|
|
||||||
knots_u=knots_u,
|
|
||||||
knots_v=knots_v,
|
|
||||||
control_points=control_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Evaluate the model
|
|
||||||
output_ = model(pts)
|
|
||||||
loss = torch.mean(output_)
|
|
||||||
loss.backward()
|
|
||||||
assert model.control_points.grad.shape == model.control_points.shape
|
|
||||||
Reference in New Issue
Block a user