fix doc model part 1
This commit is contained in:
@@ -1,19 +1,26 @@
|
||||
"""Module for Spline model"""
|
||||
"""Module for the Spline model class"""
|
||||
|
||||
import torch
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class Spline(torch.nn.Module):
|
||||
"""TODO: Docstring for Spline."""
|
||||
"""
|
||||
Spline model class.
|
||||
"""
|
||||
|
||||
def __init__(self, order=4, knots=None, control_points=None) -> None:
|
||||
"""
|
||||
Spline model.
|
||||
Initialization of the :class:`Spline` class.
|
||||
|
||||
:param int order: the order of the spline.
|
||||
:param torch.Tensor knots: the knot vector.
|
||||
:param torch.Tensor control_points: the control points.
|
||||
:param int order: The order of the spline. Default is ``4``.
|
||||
:param torch.Tensor knots: The tensor representing knots. If ``None``,
|
||||
the knots will be initialized automatically. Default is ``None``.
|
||||
:param torch.Tensor control_points: The control points. Default is
|
||||
``None``.
|
||||
:raises ValueError: If the order is negative.
|
||||
:raises ValueError: If both knots and control points are ``None``.
|
||||
:raises ValueError: If the knot tensor is not one-dimensional.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -63,13 +70,13 @@ class Spline(torch.nn.Module):
|
||||
|
||||
def basis(self, x, k, i, t):
|
||||
"""
|
||||
Recursive function to compute the basis functions of the spline.
|
||||
Recursive method to compute the basis functions of the spline.
|
||||
|
||||
:param torch.Tensor x: points to be evaluated.
|
||||
:param int k: spline degree
|
||||
:param int i: the index of the interval
|
||||
:param torch.Tensor t: vector of knots
|
||||
:return: the basis functions evaluated at x
|
||||
:param torch.Tensor x: The points to be evaluated.
|
||||
:param int k: The spline degree.
|
||||
:param int i: The index of the interval.
|
||||
:param torch.Tensor t: The tensor of knots.
|
||||
:return: The basis functions evaluated at x
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
@@ -100,11 +107,23 @@ class Spline(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def control_points(self):
|
||||
"""TODO: Docstring for control_points."""
|
||||
"""
|
||||
The control points of the spline.
|
||||
|
||||
:return: The control points.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._control_points
|
||||
|
||||
@control_points.setter
|
||||
def control_points(self, value):
|
||||
"""
|
||||
Set the control points of the spline.
|
||||
|
||||
:param value: The control points.
|
||||
:type value: torch.Tensor | dict
|
||||
:raises ValueError: If invalid value is passed.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
if "n" not in value:
|
||||
raise ValueError("Invalid value for control_points")
|
||||
@@ -118,11 +137,23 @@ class Spline(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def knots(self):
|
||||
"""TODO: Docstring for knots."""
|
||||
"""
|
||||
The knots of the spline.
|
||||
|
||||
:return: The knots.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._knots
|
||||
|
||||
@knots.setter
|
||||
def knots(self, value):
|
||||
"""
|
||||
Set the knots of the spline.
|
||||
|
||||
:param value: The knots.
|
||||
:type value: torch.Tensor | dict
|
||||
:raises ValueError: If invalid value is passed.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
|
||||
type_ = value.get("type", "auto")
|
||||
@@ -152,10 +183,10 @@ class Spline(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the spline model.
|
||||
Forward pass for the :class:`Spline` model.
|
||||
|
||||
:param torch.Tensor x: points to be evaluated.
|
||||
:return: the spline evaluated at x
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:return: The output tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
t = self.knots
|
||||
|
||||
Reference in New Issue
Block a user