🎨 Format Python code with psf/black (#348)
This commit is contained in:
committed by
GitHub
parent
4c5cb8f681
commit
f812d87727
@@ -4,6 +4,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
class Spline(torch.nn.Module):
|
class Spline(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
||||||
@@ -31,38 +32,37 @@ class Spline(torch.nn.Module):
|
|||||||
self.control_points = control_points
|
self.control_points = control_points
|
||||||
|
|
||||||
elif knots is not None:
|
elif knots is not None:
|
||||||
print('Warning: control points will be initialized automatically.')
|
print("Warning: control points will be initialized automatically.")
|
||||||
print(' experimental feature')
|
print(" experimental feature")
|
||||||
|
|
||||||
self.knots = knots
|
self.knots = knots
|
||||||
n = len(knots) - order
|
n = len(knots) - order
|
||||||
self.control_points = torch.nn.Parameter(
|
self.control_points = torch.nn.Parameter(
|
||||||
torch.zeros(n), requires_grad=True)
|
torch.zeros(n), requires_grad=True
|
||||||
|
)
|
||||||
|
|
||||||
elif control_points is not None:
|
elif control_points is not None:
|
||||||
print('Warning: knots will be initialized automatically.')
|
print("Warning: knots will be initialized automatically.")
|
||||||
print(' experimental feature')
|
print(" experimental feature")
|
||||||
|
|
||||||
self.control_points = control_points
|
self.control_points = control_points
|
||||||
|
|
||||||
n = len(self.control_points)-1
|
n = len(self.control_points) - 1
|
||||||
self.knots = {
|
self.knots = {
|
||||||
'type': 'auto',
|
"type": "auto",
|
||||||
'min': 0,
|
"min": 0,
|
||||||
'max': 1,
|
"max": 1,
|
||||||
'n': n+2+self.order}
|
"n": n + 2 + self.order,
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Knots and control points cannot be both None.")
|
||||||
"Knots and control points cannot be both None."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if self.knots.ndim != 1:
|
if self.knots.ndim != 1:
|
||||||
raise ValueError("Knot vector must be one-dimensional.")
|
raise ValueError("Knot vector must be one-dimensional.")
|
||||||
|
|
||||||
def basis(self, x, k, i, t):
|
def basis(self, x, k, i, t):
|
||||||
'''
|
"""
|
||||||
Recursive function to compute the basis functions of the spline.
|
Recursive function to compute the basis functions of the spline.
|
||||||
|
|
||||||
:param torch.Tensor x: points to be evaluated.
|
:param torch.Tensor x: points to be evaluated.
|
||||||
@@ -71,29 +71,33 @@ class Spline(torch.nn.Module):
|
|||||||
:param torch.Tensor t: vector of knots
|
:param torch.Tensor t: vector of knots
|
||||||
:return: the basis functions evaluated at x
|
:return: the basis functions evaluated at x
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
'''
|
"""
|
||||||
|
|
||||||
if k == 0:
|
if k == 0:
|
||||||
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
|
a = torch.where(
|
||||||
|
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
|
||||||
|
)
|
||||||
if i == len(t) - self.order - 1:
|
if i == len(t) - self.order - 1:
|
||||||
a = torch.where(x == t[-1], 1.0, a)
|
a = torch.where(x == t[-1], 1.0, a)
|
||||||
a.requires_grad_(True)
|
a.requires_grad_(True)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
if t[i + k] == t[i]:
|
||||||
if t[i+k] == t[i]:
|
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
|
||||||
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
|
|
||||||
else:
|
else:
|
||||||
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)
|
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
|
||||||
|
|
||||||
if t[i+k+1] == t[i+1]:
|
if t[i + k + 1] == t[i + 1]:
|
||||||
c2 = torch.tensor([0.0]*len(x), requires_grad=True)
|
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
|
||||||
else:
|
else:
|
||||||
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)
|
c2 = (
|
||||||
|
(t[i + k + 1] - x)
|
||||||
|
/ (t[i + k + 1] - t[i + 1])
|
||||||
|
* self.basis(x, k - 1, i + 1, t)
|
||||||
|
)
|
||||||
|
|
||||||
return c1 + c2
|
return c1 + c2
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def control_points(self):
|
def control_points(self):
|
||||||
return self._control_points
|
return self._control_points
|
||||||
@@ -101,14 +105,14 @@ class Spline(torch.nn.Module):
|
|||||||
@control_points.setter
|
@control_points.setter
|
||||||
def control_points(self, value):
|
def control_points(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if 'n' not in value:
|
if "n" not in value:
|
||||||
raise ValueError('Invalid value for control_points')
|
raise ValueError("Invalid value for control_points")
|
||||||
n = value['n']
|
n = value["n"]
|
||||||
dim = value.get('dim', 1)
|
dim = value.get("dim", 1)
|
||||||
value = torch.zeros(n, dim)
|
value = torch.zeros(n, dim)
|
||||||
|
|
||||||
if not isinstance(value, torch.Tensor):
|
if not isinstance(value, torch.Tensor):
|
||||||
raise ValueError('Invalid value for control_points')
|
raise ValueError("Invalid value for control_points")
|
||||||
self._control_points = torch.nn.Parameter(value, requires_grad=True)
|
self._control_points = torch.nn.Parameter(value, requires_grad=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -119,32 +123,28 @@ class Spline(torch.nn.Module):
|
|||||||
def knots(self, value):
|
def knots(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
|
|
||||||
type_ = value.get('type', 'auto')
|
type_ = value.get("type", "auto")
|
||||||
min_ = value.get('min', 0)
|
min_ = value.get("min", 0)
|
||||||
max_ = value.get('max', 1)
|
max_ = value.get("max", 1)
|
||||||
n = value.get('n', 10)
|
n = value.get("n", 10)
|
||||||
|
|
||||||
if type_ == 'uniform':
|
if type_ == "uniform":
|
||||||
value = torch.linspace(min_, max_, n + self.k + 1)
|
value = torch.linspace(min_, max_, n + self.k + 1)
|
||||||
elif type_ == 'auto':
|
elif type_ == "auto":
|
||||||
initial_knots = torch.ones(self.order+1)*min_
|
initial_knots = torch.ones(self.order + 1) * min_
|
||||||
final_knots = torch.ones(self.order+1)*max_
|
final_knots = torch.ones(self.order + 1) * max_
|
||||||
|
|
||||||
if n < self.order + 1:
|
if n < self.order + 1:
|
||||||
value = torch.concatenate((initial_knots, final_knots))
|
value = torch.concatenate((initial_knots, final_knots))
|
||||||
elif n - 2*self.order + 1 == 1:
|
elif n - 2 * self.order + 1 == 1:
|
||||||
value = torch.Tensor([(max_ + min_)/2])
|
value = torch.Tensor([(max_ + min_) / 2])
|
||||||
else:
|
else:
|
||||||
value = torch.linspace(min_, max_, n - 2*self.order - 1)
|
value = torch.linspace(min_, max_, n - 2 * self.order - 1)
|
||||||
|
|
||||||
value = torch.concatenate(
|
value = torch.concatenate((initial_knots, value, final_knots))
|
||||||
(
|
|
||||||
initial_knots, value, final_knots
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(value, torch.Tensor):
|
if not isinstance(value, torch.Tensor):
|
||||||
raise ValueError('Invalid value for knots')
|
raise ValueError("Invalid value for knots")
|
||||||
|
|
||||||
self._knots = value
|
self._knots = value
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user