doc update
This commit is contained in:
committed by
Nicola Demo
parent
b3ffca3f11
commit
24317b6fa7
@@ -148,19 +148,19 @@ class Spline(torch.nn.Module):
|
|||||||
|
|
||||||
self._knots = value
|
self._knots = value
|
||||||
|
|
||||||
def forward(self, x_):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass of the spline model.
|
Forward pass of the spline model.
|
||||||
|
|
||||||
:param torch.Tensor x_: points to be evaluated.
|
:param torch.Tensor x: points to be evaluated.
|
||||||
:return: the spline evaluated at x_
|
:return: the spline evaluated at x
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
t = self.knots
|
t = self.knots
|
||||||
k = self.k
|
k = self.k
|
||||||
c = self.control_points
|
c = self.control_points
|
||||||
|
|
||||||
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
|
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
|
||||||
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|||||||
Reference in New Issue
Block a user