add b-spline surface
This commit is contained in:
@@ -95,6 +95,7 @@ Models
|
|||||||
MultiFeedForward <model/multi_feed_forward.rst>
|
MultiFeedForward <model/multi_feed_forward.rst>
|
||||||
ResidualFeedForward <model/residual_feed_forward.rst>
|
ResidualFeedForward <model/residual_feed_forward.rst>
|
||||||
Spline <model/spline.rst>
|
Spline <model/spline.rst>
|
||||||
|
SplineSurface <model/spline_surface.rst>
|
||||||
DeepONet <model/deeponet.rst>
|
DeepONet <model/deeponet.rst>
|
||||||
MIONet <model/mionet.rst>
|
MIONet <model/mionet.rst>
|
||||||
KernelNeuralOperator <model/kernel_neural_operator.rst>
|
KernelNeuralOperator <model/kernel_neural_operator.rst>
|
||||||
|
|||||||
7
docs/source/_rst/model/spline_surface.rst
Normal file
7
docs/source/_rst/model/spline_surface.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
Spline Surface
|
||||||
|
================
|
||||||
|
.. currentmodule:: pina.model.spline_surface
|
||||||
|
|
||||||
|
.. autoclass:: SplineSurface
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -26,6 +26,7 @@ from .kernel_neural_operator import KernelNeuralOperator
|
|||||||
from .average_neural_operator import AveragingNeuralOperator
|
from .average_neural_operator import AveragingNeuralOperator
|
||||||
from .low_rank_neural_operator import LowRankNeuralOperator
|
from .low_rank_neural_operator import LowRankNeuralOperator
|
||||||
from .spline import Spline
|
from .spline import Spline
|
||||||
|
from .spline_surface import SplineSurface
|
||||||
from .graph_neural_operator import GraphNeuralOperator
|
from .graph_neural_operator import GraphNeuralOperator
|
||||||
from .pirate_network import PirateNet
|
from .pirate_network import PirateNet
|
||||||
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
|
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Module for the B-Spline model class."""
|
"""Module for the B-Spline model class."""
|
||||||
|
|
||||||
import torch
|
|
||||||
import warnings
|
import warnings
|
||||||
from ..utils import check_positive_integer
|
import torch
|
||||||
|
from ..utils import check_positive_integer, check_consistency
|
||||||
|
|
||||||
|
|
||||||
class Spline(torch.nn.Module):
|
class Spline(torch.nn.Module):
|
||||||
@@ -75,6 +75,10 @@ class Spline(torch.nn.Module):
|
|||||||
If None, they are initialized as learnable parameters with an
|
If None, they are initialized as learnable parameters with an
|
||||||
initial value of zero. Default is None.
|
initial value of zero. Default is None.
|
||||||
:raises AssertionError: If ``order`` is not a positive integer.
|
:raises AssertionError: If ``order`` is not a positive integer.
|
||||||
|
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
|
||||||
|
dictionary, when provided.
|
||||||
|
:raises ValueError: If ``control_points`` is not a torch.Tensor,
|
||||||
|
when provided.
|
||||||
:raises ValueError: If both ``knots`` and ``control_points`` are None.
|
:raises ValueError: If both ``knots`` and ``control_points`` are None.
|
||||||
:raises ValueError: If ``knots`` is not one-dimensional.
|
:raises ValueError: If ``knots`` is not one-dimensional.
|
||||||
:raises ValueError: If ``control_points`` is not one-dimensional.
|
:raises ValueError: If ``control_points`` is not one-dimensional.
|
||||||
@@ -87,6 +91,8 @@ class Spline(torch.nn.Module):
|
|||||||
|
|
||||||
# Check consistency
|
# Check consistency
|
||||||
check_positive_integer(value=order, strict=True)
|
check_positive_integer(value=order, strict=True)
|
||||||
|
check_consistency(knots, (type(None), torch.Tensor, dict))
|
||||||
|
check_consistency(control_points, (type(None), torch.Tensor))
|
||||||
|
|
||||||
# Raise error if neither knots nor control points are provided
|
# Raise error if neither knots nor control points are provided
|
||||||
if knots is None and control_points is None:
|
if knots is None and control_points is None:
|
||||||
@@ -229,10 +235,10 @@ class Spline(torch.nn.Module):
|
|||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
return torch.einsum(
|
return torch.einsum(
|
||||||
"bi, i -> b",
|
"...bi, i -> ...b",
|
||||||
self.basis(x.as_subclass(torch.Tensor)).squeeze(1),
|
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
|
||||||
self.control_points,
|
self.control_points,
|
||||||
).reshape(-1, 1)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def control_points(self):
|
def control_points(self):
|
||||||
@@ -254,7 +260,6 @@ class Spline(torch.nn.Module):
|
|||||||
initial value. Default is None.
|
initial value. Default is None.
|
||||||
:raises ValueError: If there are not enough knots to define the control
|
:raises ValueError: If there are not enough knots to define the control
|
||||||
points, due to the relation: #knots = order + #control_points.
|
points, due to the relation: #knots = order + #control_points.
|
||||||
:raises ValueError: If control_points is not a torch.Tensor.
|
|
||||||
"""
|
"""
|
||||||
# If control points are not provided, initialize them
|
# If control points are not provided, initialize them
|
||||||
if control_points is None:
|
if control_points is None:
|
||||||
@@ -270,13 +275,6 @@ class Spline(torch.nn.Module):
|
|||||||
# Initialize control points to zero
|
# Initialize control points to zero
|
||||||
control_points = torch.zeros(len(self.knots) - self.order)
|
control_points = torch.zeros(len(self.knots) - self.order)
|
||||||
|
|
||||||
# Check validity of control points
|
|
||||||
elif not isinstance(control_points, torch.Tensor):
|
|
||||||
raise ValueError(
|
|
||||||
"control_points must be a torch.Tensor,"
|
|
||||||
f" got {type(control_points)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set control points
|
# Set control points
|
||||||
self._control_points = torch.nn.Parameter(
|
self._control_points = torch.nn.Parameter(
|
||||||
control_points, requires_grad=True
|
control_points, requires_grad=True
|
||||||
@@ -308,18 +306,10 @@ class Spline(torch.nn.Module):
|
|||||||
last control points. In this case, the number of knots is inferred
|
last control points. In this case, the number of knots is inferred
|
||||||
and the ``"n"`` key is ignored.
|
and the ``"n"`` key is ignored.
|
||||||
:type value: torch.Tensor | dict
|
:type value: torch.Tensor | dict
|
||||||
:raises ValueError: If value is not a torch.Tensor or a dictionary.
|
|
||||||
:raises ValueError: If a dictionary is provided but does not contain
|
:raises ValueError: If a dictionary is provided but does not contain
|
||||||
the required keys.
|
the required keys.
|
||||||
:raises ValueError: If the mode specified in the dictionary is invalid.
|
:raises ValueError: If the mode specified in the dictionary is invalid.
|
||||||
"""
|
"""
|
||||||
# Check validity of knots
|
|
||||||
if not isinstance(value, (torch.Tensor, dict)):
|
|
||||||
raise ValueError(
|
|
||||||
"Knots must be a torch.Tensor or a dictionary,"
|
|
||||||
f" got {type(value)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# If a dictionary is provided, initialize knots accordingly
|
# If a dictionary is provided, initialize knots accordingly
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
|
|
||||||
|
|||||||
212
pina/model/spline_surface.py
Normal file
212
pina/model/spline_surface.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Module for the bivariate B-Spline surface model class."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from .spline import Spline
|
||||||
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class SplineSurface(torch.nn.Module):
|
||||||
|
r"""
|
||||||
|
The bivariate B-Spline surface model class.
|
||||||
|
|
||||||
|
A bivariate B-spline surface is a parametric surface defined as the tensor
|
||||||
|
product of two univariate B-spline curves:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j},
|
||||||
|
\quad x \in [x_1, x_m], y \in [y_1, y_l]
|
||||||
|
|
||||||
|
where:
|
||||||
|
|
||||||
|
- :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed
|
||||||
|
points influence the shape of the surface but are not generally
|
||||||
|
interpolated, except at the boundaries under certain knot multiplicities.
|
||||||
|
- :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions
|
||||||
|
defined over two orthogonal directions, with orders :math:`k` and
|
||||||
|
:math:`s`, respectively.
|
||||||
|
- :math:`X = \{ x_1, x_2, \dots, x_m \}` and
|
||||||
|
:math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot
|
||||||
|
vectors along the two directions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, orders, knots_u=None, knots_v=None, control_points=None):
|
||||||
|
"""
|
||||||
|
Initialization of the :class:`SplineSurface` class.
|
||||||
|
|
||||||
|
:param list[int] orders: The orders of the spline along each parametric
|
||||||
|
direction. Each order defines the degree of the corresponding basis
|
||||||
|
as ``degree = order - 1``.
|
||||||
|
:param knots_u: The knots of the spline along the first direction.
|
||||||
|
For details on valid formats and initialization modes, see the
|
||||||
|
:class:`Spline` class. Default is None.
|
||||||
|
:type knots_u: torch.Tensor | dict
|
||||||
|
:param knots_v: The knots of the spline along the second direction.
|
||||||
|
For details on valid formats and initialization modes, see the
|
||||||
|
:class:`Spline` class. Default is None.
|
||||||
|
:type knots_v: torch.Tensor | dict
|
||||||
|
:param torch.Tensor control_points: The control points defining the
|
||||||
|
surface geometry. It must be a two-dimensional tensor of shape
|
||||||
|
``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``.
|
||||||
|
If None, they are initialized as learnable parameters with zero
|
||||||
|
values. Default is None.
|
||||||
|
:raises ValueError: If ``orders`` is not a list of integers.
|
||||||
|
:raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a
|
||||||
|
dictionary, when provided.
|
||||||
|
:raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a
|
||||||
|
dictionary, when provided.
|
||||||
|
:raises ValueError: If ``control_points`` is not a torch.Tensor,
|
||||||
|
when provided.
|
||||||
|
:raises ValueError: If ``orders`` is not a list of two elements.
|
||||||
|
:raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points``
|
||||||
|
are all None.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Check consistency
|
||||||
|
check_consistency(orders, int)
|
||||||
|
check_consistency(control_points, (type(None), torch.Tensor))
|
||||||
|
check_consistency(knots_u, (type(None), torch.Tensor, dict))
|
||||||
|
check_consistency(knots_v, (type(None), torch.Tensor, dict))
|
||||||
|
|
||||||
|
# Check orders is a list of two elements
|
||||||
|
if len(orders) != 2:
|
||||||
|
raise ValueError("orders must be a list of two elements.")
|
||||||
|
|
||||||
|
# Raise error if neither knots nor control points are provided
|
||||||
|
if (knots_u is None or knots_v is None) and control_points is None:
|
||||||
|
raise ValueError(
|
||||||
|
"control_points cannot be None if knots_u or knots_v is None."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize knots_u if not provided
|
||||||
|
if knots_u is None and control_points is not None:
|
||||||
|
knots_u = {
|
||||||
|
"n": control_points.shape[0] + orders[0],
|
||||||
|
"min": 0,
|
||||||
|
"max": 1,
|
||||||
|
"mode": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize knots_v if not provided
|
||||||
|
if knots_v is None and control_points is not None:
|
||||||
|
knots_v = {
|
||||||
|
"n": control_points.shape[1] + orders[1],
|
||||||
|
"min": 0,
|
||||||
|
"max": 1,
|
||||||
|
"mode": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create two univariate b-splines
|
||||||
|
self.spline_u = Spline(order=orders[0], knots=knots_u)
|
||||||
|
self.spline_v = Spline(order=orders[1], knots=knots_v)
|
||||||
|
self.control_points = control_points
|
||||||
|
|
||||||
|
# Delete unneeded parameters
|
||||||
|
delattr(self.spline_u, "_control_points")
|
||||||
|
delattr(self.spline_v, "_control_points")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass for the :class:`SplineSurface` model.
|
||||||
|
|
||||||
|
:param x: The input tensor.
|
||||||
|
:type x: torch.Tensor | LabelTensor
|
||||||
|
:return: The output tensor.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
return torch.einsum(
|
||||||
|
"...bi, ...bj, ij -> ...b",
|
||||||
|
self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]),
|
||||||
|
self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]),
|
||||||
|
self.control_points,
|
||||||
|
).unsqueeze(-1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def knots(self):
|
||||||
|
"""
|
||||||
|
The knots of the univariate splines defining the spline surface.
|
||||||
|
|
||||||
|
:return: The knots.
|
||||||
|
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||||
|
"""
|
||||||
|
return self.spline_u.knots, self.spline_v.knots
|
||||||
|
|
||||||
|
@knots.setter
|
||||||
|
def knots(self, value):
|
||||||
|
"""
|
||||||
|
Set the knots of the spline surface.
|
||||||
|
|
||||||
|
:param value: A tuple (knots_u, knots_v) containing the knots for both
|
||||||
|
parametric directions.
|
||||||
|
:type value: tuple(torch.Tensor | dict, torch.Tensor | dict)
|
||||||
|
:raises ValueError: If value is not a tuple of two elements.
|
||||||
|
"""
|
||||||
|
# Check value is a tuple of two elements
|
||||||
|
if not (isinstance(value, tuple) and len(value) == 2):
|
||||||
|
raise ValueError("Knots must be a tuple of two elements.")
|
||||||
|
|
||||||
|
knots_u, knots_v = value
|
||||||
|
self.spline_u.knots = knots_u
|
||||||
|
self.spline_v.knots = knots_v
|
||||||
|
|
||||||
|
@property
|
||||||
|
def control_points(self):
|
||||||
|
"""
|
||||||
|
The control points of the spline.
|
||||||
|
|
||||||
|
:return: The control points.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
return self._control_points
|
||||||
|
|
||||||
|
@control_points.setter
|
||||||
|
def control_points(self, control_points):
|
||||||
|
"""
|
||||||
|
Set the control points of the spline surface.
|
||||||
|
|
||||||
|
:param torch.Tensor control_points: The bidimensional control points
|
||||||
|
tensor, where each dimension refers to a direction in the parameter
|
||||||
|
space. If None, control points are initialized to learnable
|
||||||
|
parameters with zero initial value. Default is None.
|
||||||
|
:raises ValueError: If in any direction there are not enough knots to
|
||||||
|
define the control points, due to the relation:
|
||||||
|
#knots = order + #control_points.
|
||||||
|
:raises ValueError: If ``control_points`` is not of the correct shape.
|
||||||
|
"""
|
||||||
|
# Save correct shape of control points
|
||||||
|
__valid_shape = (
|
||||||
|
len(self.spline_u.knots) - self.spline_u.order,
|
||||||
|
len(self.spline_v.knots) - self.spline_v.order,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If control points are not provided, initialize them
|
||||||
|
if control_points is None:
|
||||||
|
|
||||||
|
# Check that there are enough knots to define control points
|
||||||
|
if (
|
||||||
|
len(self.spline_u.knots) < self.spline_u.order + 1
|
||||||
|
or len(self.spline_v.knots) < self.spline_v.order + 1
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough knots to define control points. Got "
|
||||||
|
f"{len(self.spline_u.knots)} knots along u and "
|
||||||
|
f"{len(self.spline_v.knots)} knots along v, but need at "
|
||||||
|
f"least {self.spline_u.order + 1} and "
|
||||||
|
f"{self.spline_v.order + 1}, respectively."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize control points to zero
|
||||||
|
control_points = torch.zeros(__valid_shape)
|
||||||
|
|
||||||
|
# Check control points
|
||||||
|
if control_points.shape != __valid_shape:
|
||||||
|
raise ValueError(
|
||||||
|
"control_points must be of the correct shape. ",
|
||||||
|
f"Expected {__valid_shape}, got {control_points.shape}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register control points as a learnable parameter
|
||||||
|
self._control_points = torch.nn.Parameter(
|
||||||
|
control_points, requires_grad=True
|
||||||
|
)
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
|
||||||
from scipy.interpolate import BSpline
|
from scipy.interpolate import BSpline
|
||||||
from pina.model import Spline
|
from pina.model import Spline
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
@@ -12,7 +11,10 @@ n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
|
|||||||
n_knots = order + n_ctrl_pts
|
n_knots = order + n_ctrl_pts
|
||||||
|
|
||||||
# Input tensor
|
# Input tensor
|
||||||
pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"])
|
points = [
|
||||||
|
LabelTensor(torch.rand(100, 1), ["x"]),
|
||||||
|
LabelTensor(torch.rand(2, 100, 1), ["x"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# Function to compare with scipy implementation
|
# Function to compare with scipy implementation
|
||||||
@@ -26,15 +28,15 @@ def check_scipy_spline(model, x, output_):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compare outputs
|
# Compare outputs
|
||||||
np.testing.assert_allclose(
|
torch.allclose(
|
||||||
output_.squeeze().detach().numpy(),
|
output_,
|
||||||
scipy_spline(x).flatten(),
|
torch.tensor(scipy_spline(x), dtype=output_.dtype),
|
||||||
atol=1e-5,
|
atol=1e-5,
|
||||||
rtol=1e-5,
|
rtol=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Define all possible combinations of valid arguments for the Spline class
|
# Define all possible combinations of valid arguments for Spline class
|
||||||
valid_args = [
|
valid_args = [
|
||||||
{
|
{
|
||||||
"order": order,
|
"order": order,
|
||||||
@@ -144,14 +146,15 @@ def test_constructor(args):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", valid_args)
|
@pytest.mark.parametrize("args", valid_args)
|
||||||
def test_forward(args):
|
@pytest.mark.parametrize("pts", points)
|
||||||
|
def test_forward(args, pts):
|
||||||
|
|
||||||
# Define the model
|
# Define the model
|
||||||
model = Spline(**args)
|
model = Spline(**args)
|
||||||
|
|
||||||
# Evaluate the model
|
# Evaluate the model
|
||||||
output_ = model(pts)
|
output_ = model(pts)
|
||||||
assert output_.shape == (pts.shape[0], 1)
|
assert output_.shape == pts.shape
|
||||||
|
|
||||||
# Compare with scipy implementation only for interpolant knots (mode: auto)
|
# Compare with scipy implementation only for interpolant knots (mode: auto)
|
||||||
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
|
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
|
||||||
@@ -159,7 +162,8 @@ def test_forward(args):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", valid_args)
|
@pytest.mark.parametrize("args", valid_args)
|
||||||
def test_backward(args):
|
@pytest.mark.parametrize("pts", points)
|
||||||
|
def test_backward(args, pts):
|
||||||
|
|
||||||
# Define the model
|
# Define the model
|
||||||
model = Spline(**args)
|
model = Spline(**args)
|
||||||
|
|||||||
180
tests/test_model/test_spline_surface.py
Normal file
180
tests/test_model/test_spline_surface.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import pytest
|
||||||
|
from pina.model import SplineSurface
|
||||||
|
from pina import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
|
# Utility quantities for testing
|
||||||
|
orders = [random.randint(1, 8) for _ in range(2)]
|
||||||
|
n_ctrl_pts = random.randint(max(orders), max(orders) + 5)
|
||||||
|
n_knots = [orders[i] + n_ctrl_pts for i in range(2)]
|
||||||
|
|
||||||
|
# Input tensor
|
||||||
|
points = [
|
||||||
|
LabelTensor(torch.rand(100, 2), ["x", "y"]),
|
||||||
|
LabelTensor(torch.rand(2, 100, 2), ["x", "y"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_u",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[0]),
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_v",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[1]),
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
||||||
|
)
|
||||||
|
def test_constructor(knots_u, knots_v, control_points):
|
||||||
|
|
||||||
|
# Skip if knots_u, knots_v, and control_points are all None
|
||||||
|
if (knots_u is None or knots_v is None) and control_points is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=control_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if orders is not list of two elements
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=[orders[0]],
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=control_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if both knots and control_points are None
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=None,
|
||||||
|
knots_v=None,
|
||||||
|
control_points=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if control_points is not a torch.Tensor when provided
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if control_points is not of the correct shape when provided
|
||||||
|
# It assumes that at least one among knots_u and knots_v is not None
|
||||||
|
if knots_u is not None or knots_v is not None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if there are not enough knots_u to define the control points
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=torch.linspace(0, 1, orders[0]),
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if there are not enough knots_v to define the control points
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=torch.linspace(0, 1, orders[1]),
|
||||||
|
control_points=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_u",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[0]),
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_v",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[1]),
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("pts", points)
|
||||||
|
def test_forward(knots_u, knots_v, control_points, pts):
|
||||||
|
|
||||||
|
# Define the model
|
||||||
|
model = SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=control_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate the model
|
||||||
|
output_ = model(pts)
|
||||||
|
assert output_.shape == (*pts.shape[:-1], 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_u",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[0]),
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"knots_v",
|
||||||
|
[
|
||||||
|
torch.rand(n_knots[1]),
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
|
||||||
|
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("pts", points)
|
||||||
|
def test_backward(knots_u, knots_v, control_points, pts):
|
||||||
|
|
||||||
|
# Define the model
|
||||||
|
model = SplineSurface(
|
||||||
|
orders=orders,
|
||||||
|
knots_u=knots_u,
|
||||||
|
knots_v=knots_v,
|
||||||
|
control_points=control_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate the model
|
||||||
|
output_ = model(pts)
|
||||||
|
loss = torch.mean(output_)
|
||||||
|
loss.backward()
|
||||||
|
assert model.control_points.grad.shape == model.control_points.shape
|
||||||
Reference in New Issue
Block a user