diff --git a/pina/model/spline.py b/pina/model/spline.py index 2c5aa6e..2328986 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -148,19 +148,19 @@ class Spline(torch.nn.Module): self._knots = value - def forward(self, x_): + def forward(self, x): """ Forward pass of the spline model. - :param torch.Tensor x_: points to be evaluated. - :return: the spline evaluated at x_ + :param torch.Tensor x: points to be evaluated. + :return: the spline evaluated at x :rtype: torch.Tensor """ t = self.knots k = self.k c = self.control_points - basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c))) + basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) return y diff --git a/setup.py b/setup.py index b1909a6..8c7b9ac 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ REQUIRED = [ EXTRAS = { 'docs': [ - 'sphinx', + 'sphinx>5.0', 'sphinx_rtd_theme', 'sphinx_copybutton', 'sphinx_design',