🎨 Format Python code with psf/black (#348)

This commit is contained in:
github-actions[bot]
2024-09-27 10:35:37 +02:00
committed by GitHub
parent 4c5cb8f681
commit f812d87727
2 changed files with 60 additions and 60 deletions

View File

@@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..utils import check_consistency from ..utils import check_consistency
class Spline(torch.nn.Module): class Spline(torch.nn.Module):
def __init__(self, order=4, knots=None, control_points=None) -> None: def __init__(self, order=4, knots=None, control_points=None) -> None:
@@ -31,38 +32,37 @@ class Spline(torch.nn.Module):
self.control_points = control_points self.control_points = control_points
elif knots is not None: elif knots is not None:
print('Warning: control points will be initialized automatically.') print("Warning: control points will be initialized automatically.")
print(' experimental feature') print(" experimental feature")
self.knots = knots self.knots = knots
n = len(knots) - order n = len(knots) - order
self.control_points = torch.nn.Parameter( self.control_points = torch.nn.Parameter(
torch.zeros(n), requires_grad=True) torch.zeros(n), requires_grad=True
)
elif control_points is not None: elif control_points is not None:
print('Warning: knots will be initialized automatically.') print("Warning: knots will be initialized automatically.")
print(' experimental feature') print(" experimental feature")
self.control_points = control_points self.control_points = control_points
n = len(self.control_points)-1 n = len(self.control_points) - 1
self.knots = { self.knots = {
'type': 'auto', "type": "auto",
'min': 0, "min": 0,
'max': 1, "max": 1,
'n': n+2+self.order} "n": n + 2 + self.order,
}
else: else:
raise ValueError( raise ValueError("Knots and control points cannot be both None.")
"Knots and control points cannot be both None."
)
if self.knots.ndim != 1: if self.knots.ndim != 1:
raise ValueError("Knot vector must be one-dimensional.") raise ValueError("Knot vector must be one-dimensional.")
def basis(self, x, k, i, t): def basis(self, x, k, i, t):
''' """
Recursive function to compute the basis functions of the spline. Recursive function to compute the basis functions of the spline.
:param torch.Tensor x: points to be evaluated. :param torch.Tensor x: points to be evaluated.
@@ -71,29 +71,33 @@ class Spline(torch.nn.Module):
:param torch.Tensor t: vector of knots :param torch.Tensor t: vector of knots
:return: the basis functions evaluated at x :return: the basis functions evaluated at x
:rtype: torch.Tensor :rtype: torch.Tensor
''' """
if k == 0: if k == 0:
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0) a = torch.where(
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
)
if i == len(t) - self.order - 1: if i == len(t) - self.order - 1:
a = torch.where(x == t[-1], 1.0, a) a = torch.where(x == t[-1], 1.0, a)
a.requires_grad_(True) a.requires_grad_(True)
return a return a
if t[i + k] == t[i]:
if t[i+k] == t[i]: c1 = torch.tensor([0.0] * len(x), requires_grad=True)
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
else: else:
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t) c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
if t[i+k+1] == t[i+1]: if t[i + k + 1] == t[i + 1]:
c2 = torch.tensor([0.0]*len(x), requires_grad=True) c2 = torch.tensor([0.0] * len(x), requires_grad=True)
else: else:
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t) c2 = (
(t[i + k + 1] - x)
/ (t[i + k + 1] - t[i + 1])
* self.basis(x, k - 1, i + 1, t)
)
return c1 + c2 return c1 + c2
@property @property
def control_points(self): def control_points(self):
return self._control_points return self._control_points
@@ -101,14 +105,14 @@ class Spline(torch.nn.Module):
@control_points.setter @control_points.setter
def control_points(self, value): def control_points(self, value):
if isinstance(value, dict): if isinstance(value, dict):
if 'n' not in value: if "n" not in value:
raise ValueError('Invalid value for control_points') raise ValueError("Invalid value for control_points")
n = value['n'] n = value["n"]
dim = value.get('dim', 1) dim = value.get("dim", 1)
value = torch.zeros(n, dim) value = torch.zeros(n, dim)
if not isinstance(value, torch.Tensor): if not isinstance(value, torch.Tensor):
raise ValueError('Invalid value for control_points') raise ValueError("Invalid value for control_points")
self._control_points = torch.nn.Parameter(value, requires_grad=True) self._control_points = torch.nn.Parameter(value, requires_grad=True)
@property @property
@@ -119,32 +123,28 @@ class Spline(torch.nn.Module):
def knots(self, value): def knots(self, value):
if isinstance(value, dict): if isinstance(value, dict):
type_ = value.get('type', 'auto') type_ = value.get("type", "auto")
min_ = value.get('min', 0) min_ = value.get("min", 0)
max_ = value.get('max', 1) max_ = value.get("max", 1)
n = value.get('n', 10) n = value.get("n", 10)
if type_ == 'uniform': if type_ == "uniform":
value = torch.linspace(min_, max_, n + self.k + 1) value = torch.linspace(min_, max_, n + self.k + 1)
elif type_ == 'auto': elif type_ == "auto":
initial_knots = torch.ones(self.order+1)*min_ initial_knots = torch.ones(self.order + 1) * min_
final_knots = torch.ones(self.order+1)*max_ final_knots = torch.ones(self.order + 1) * max_
if n < self.order + 1: if n < self.order + 1:
value = torch.concatenate((initial_knots, final_knots)) value = torch.concatenate((initial_knots, final_knots))
elif n - 2*self.order + 1 == 1: elif n - 2 * self.order + 1 == 1:
value = torch.Tensor([(max_ + min_)/2]) value = torch.Tensor([(max_ + min_) / 2])
else: else:
value = torch.linspace(min_, max_, n - 2*self.order - 1) value = torch.linspace(min_, max_, n - 2 * self.order - 1)
value = torch.concatenate( value = torch.concatenate((initial_knots, value, final_knots))
(
initial_knots, value, final_knots
)
)
if not isinstance(value, torch.Tensor): if not isinstance(value, torch.Tensor):
raise ValueError('Invalid value for knots') raise ValueError("Invalid value for knots")
self._knots = value self._knots = value