doc update
This commit is contained in:
committed by
Nicola Demo
parent
b3ffca3f11
commit
24317b6fa7
@@ -148,19 +148,19 @@ class Spline(torch.nn.Module):
|
||||
|
||||
self._knots = value
|
||||
|
||||
def forward(self, x_):
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the spline model.
|
||||
|
||||
:param torch.Tensor x_: points to be evaluated.
|
||||
:return: the spline evaluated at x_
|
||||
:param torch.Tensor x: points to be evaluated.
|
||||
:return: the spline evaluated at x
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
t = self.knots
|
||||
k = self.k
|
||||
c = self.control_points
|
||||
|
||||
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
|
||||
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
|
||||
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
||||
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user