doc update

This commit is contained in:
Dario Coscia
2024-10-01 15:54:55 +02:00
committed by Nicola Demo
parent b3ffca3f11
commit 24317b6fa7
2 changed files with 5 additions and 5 deletions

View File

@@ -148,19 +148,19 @@ class Spline(torch.nn.Module):
self._knots = value
def forward(self, x_):
def forward(self, x):
"""
Forward pass of the spline model.
:param torch.Tensor x_: points to be evaluated.
:return: the spline evaluated at x_
:param torch.Tensor x: points to be evaluated.
:return: the spline evaluated at x
:rtype: torch.Tensor
"""
t = self.knots
k = self.k
c = self.control_points
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
return y