add spline model (#321)
* add spline model * add tests for splines * rst files for splines --------- Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
@@ -59,6 +59,7 @@ Models
|
||||
FeedForward <models/fnn.rst>
|
||||
MultiFeedForward <models/multifeedforward.rst>
|
||||
ResidualFeedForward <models/fnn_residual.rst>
|
||||
Spline <models/spline.rst>
|
||||
DeepONet <models/deeponet.rst>
|
||||
MIONet <models/mionet.rst>
|
||||
FourierIntegralKernel <models/fourier_kernel.rst>
|
||||
|
||||
7
docs/source/_rst/models/spline.rst
Normal file
7
docs/source/_rst/models/spline.rst
Normal file
@@ -0,0 +1,7 @@
|
||||
Spline
|
||||
========
|
||||
.. currentmodule:: pina.model.spline
|
||||
|
||||
.. autoclass:: Spline
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -9,6 +9,7 @@ __all__ = [
|
||||
"KernelNeuralOperator",
|
||||
"AveragingNeuralOperator",
|
||||
"LowRankNeuralOperator",
|
||||
"Spline",
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward, ResidualFeedForward
|
||||
@@ -18,3 +19,4 @@ from .fno import FNO, FourierIntegralKernel
|
||||
from .base_no import KernelNeuralOperator
|
||||
from .avno import AveragingNeuralOperator
|
||||
from .lno import LowRankNeuralOperator
|
||||
from .spline import Spline
|
||||
166
pina/model/spline.py
Normal file
166
pina/model/spline.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Module for Spline model"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency
|
||||
|
||||
class Spline(torch.nn.Module):
|
||||
|
||||
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
||||
"""
|
||||
Spline model.
|
||||
|
||||
:param int order: the order of the spline.
|
||||
:param torch.Tensor knots: the knot vector.
|
||||
:param torch.Tensor control_points: the control points.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
check_consistency(order, int)
|
||||
|
||||
if order < 0:
|
||||
raise ValueError("Spline order cannot be negative.")
|
||||
if knots is None and control_points is None:
|
||||
raise ValueError("Knots and control points cannot be both None.")
|
||||
|
||||
self.order = order
|
||||
self.k = order - 1
|
||||
|
||||
if knots is not None and control_points is not None:
|
||||
self.knots = knots
|
||||
self.control_points = control_points
|
||||
|
||||
elif knots is not None:
|
||||
print('Warning: control points will be initialized automatically.')
|
||||
print(' experimental feature')
|
||||
|
||||
self.knots = knots
|
||||
n = len(knots) - order
|
||||
self.control_points = torch.nn.Parameter(
|
||||
torch.zeros(n), requires_grad=True)
|
||||
|
||||
elif control_points is not None:
|
||||
print('Warning: knots will be initialized automatically.')
|
||||
print(' experimental feature')
|
||||
|
||||
self.control_points = control_points
|
||||
|
||||
n = len(self.control_points)-1
|
||||
self.knots = {
|
||||
'type': 'auto',
|
||||
'min': 0,
|
||||
'max': 1,
|
||||
'n': n+2+self.order}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Knots and control points cannot be both None."
|
||||
)
|
||||
|
||||
|
||||
if self.knots.ndim != 1:
|
||||
raise ValueError("Knot vector must be one-dimensional.")
|
||||
|
||||
def basis(self, x, k, i, t):
|
||||
'''
|
||||
Recursive function to compute the basis functions of the spline.
|
||||
|
||||
:param torch.Tensor x: points to be evaluated.
|
||||
:param int k: spline degree
|
||||
:param int i: the index of the interval
|
||||
:param torch.Tensor t: vector of knots
|
||||
:return: the basis functions evaluated at x
|
||||
:rtype: torch.Tensor
|
||||
'''
|
||||
|
||||
if k == 0:
|
||||
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
|
||||
if i == len(t) - self.order - 1:
|
||||
a = torch.where(x == t[-1], 1.0, a)
|
||||
a.requires_grad_(True)
|
||||
return a
|
||||
|
||||
|
||||
if t[i+k] == t[i]:
|
||||
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
|
||||
else:
|
||||
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)
|
||||
|
||||
if t[i+k+1] == t[i+1]:
|
||||
c2 = torch.tensor([0.0]*len(x), requires_grad=True)
|
||||
else:
|
||||
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)
|
||||
|
||||
return c1 + c2
|
||||
|
||||
|
||||
@property
|
||||
def control_points(self):
|
||||
return self._control_points
|
||||
|
||||
@control_points.setter
|
||||
def control_points(self, value):
|
||||
if isinstance(value, dict):
|
||||
if 'n' not in value:
|
||||
raise ValueError('Invalid value for control_points')
|
||||
n = value['n']
|
||||
dim = value.get('dim', 1)
|
||||
value = torch.zeros(n, dim)
|
||||
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise ValueError('Invalid value for control_points')
|
||||
self._control_points = torch.nn.Parameter(value, requires_grad=True)
|
||||
|
||||
@property
|
||||
def knots(self):
|
||||
return self._knots
|
||||
|
||||
@knots.setter
|
||||
def knots(self, value):
|
||||
if isinstance(value, dict):
|
||||
|
||||
type_ = value.get('type', 'auto')
|
||||
min_ = value.get('min', 0)
|
||||
max_ = value.get('max', 1)
|
||||
n = value.get('n', 10)
|
||||
|
||||
if type_ == 'uniform':
|
||||
value = torch.linspace(min_, max_, n + self.k + 1)
|
||||
elif type_ == 'auto':
|
||||
initial_knots = torch.ones(self.order+1)*min_
|
||||
final_knots = torch.ones(self.order+1)*max_
|
||||
|
||||
if n < self.order + 1:
|
||||
value = torch.concatenate((initial_knots, final_knots))
|
||||
elif n - 2*self.order + 1 == 1:
|
||||
value = torch.Tensor([(max_ + min_)/2])
|
||||
else:
|
||||
value = torch.linspace(min_, max_, n - 2*self.order - 1)
|
||||
|
||||
value = torch.concatenate(
|
||||
(
|
||||
initial_knots, value, final_knots
|
||||
)
|
||||
)
|
||||
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise ValueError('Invalid value for knots')
|
||||
|
||||
self._knots = value
|
||||
|
||||
def forward(self, x_):
|
||||
"""
|
||||
Forward pass of the spline model.
|
||||
|
||||
:param torch.Tensor x_: points to be evaluated.
|
||||
:return: the spline evaluated at x_
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
t = self.knots
|
||||
k = self.k
|
||||
c = self.control_points
|
||||
|
||||
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
|
||||
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
||||
|
||||
return y
|
||||
6
setup.py
6
setup.py
@@ -26,7 +26,11 @@ EXTRAS = {
|
||||
'sphinx_design',
|
||||
'pydata_sphinx_theme'
|
||||
],
|
||||
'test': ['pytest', 'pytest-cov'],
|
||||
'test': [
|
||||
'pytest',
|
||||
'pytest-cov',
|
||||
'scipy'
|
||||
],
|
||||
}
|
||||
|
||||
LDESCRIPTION = (
|
||||
|
||||
74
tests/test_model/test_spline.py
Normal file
74
tests/test_model/test_spline.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina.model import Spline
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
input_vars = 3
|
||||
output_vars = 4
|
||||
|
||||
valid_args = [
|
||||
{
|
||||
'knots': torch.tensor([0., 0., 0., 1., 2., 3., 3., 3.]),
|
||||
'control_points': torch.tensor([0., 0., 1., 0., 0.]),
|
||||
'order': 3
|
||||
},
|
||||
{
|
||||
'knots': torch.tensor([-2., -2., -2., -2., -1., 0., 1., 2., 2., 2., 2.]),
|
||||
'control_points': torch.tensor([0., 0., 0., 6., 0., 0., 0.]),
|
||||
'order': 4
|
||||
},
|
||||
# {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
|
||||
# {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
|
||||
]
|
||||
|
||||
def scipy_check(model, x, y):
|
||||
from scipy.interpolate._bsplines import BSpline
|
||||
import numpy as np
|
||||
spline = BSpline(
|
||||
t=model.knots.detach().numpy(),
|
||||
c=model.control_points.detach().numpy(),
|
||||
k=model.order-1
|
||||
)
|
||||
y_scipy = spline(x).flatten()
|
||||
y = y.detach().numpy()
|
||||
np.testing.assert_allclose(y, y_scipy, atol=1e-5)
|
||||
|
||||
@pytest.mark.parametrize("args", valid_args)
|
||||
def test_constructor(args):
|
||||
Spline(**args)
|
||||
|
||||
def test_constructor_wrong():
|
||||
with pytest.raises(ValueError):
|
||||
Spline()
|
||||
|
||||
@pytest.mark.parametrize("args", valid_args)
|
||||
def test_forward(args):
|
||||
min_x = args['knots'][0]
|
||||
max_x = args['knots'][-1]
|
||||
xi = torch.linspace(min_x, max_x, 1000)
|
||||
model = Spline(**args)
|
||||
yi = model(xi).squeeze()
|
||||
scipy_check(model, xi, yi)
|
||||
return
|
||||
|
||||
|
||||
@pytest.mark.parametrize("args", valid_args)
|
||||
def test_backward(args):
|
||||
min_x = args['knots'][0]
|
||||
max_x = args['knots'][-1]
|
||||
xi = torch.linspace(min_x, max_x, 100)
|
||||
model = Spline(**args)
|
||||
yi = model(xi)
|
||||
fake_loss = torch.sum(yi)
|
||||
assert model.control_points.grad is None
|
||||
fake_loss.backward()
|
||||
assert model.control_points.grad is not None
|
||||
|
||||
# dim_in, dim_out = 3, 2
|
||||
# fnn = FeedForward(dim_in, dim_out)
|
||||
# data.requires_grad = True
|
||||
# output_ = fnn(data)
|
||||
# l=torch.mean(output_)
|
||||
# l.backward()
|
||||
# assert data._grad.shape == torch.Size([20,3])
|
||||
Reference in New Issue
Block a user