fix doc model part 1

This commit is contained in:
giovanni
2025-03-14 12:24:27 +01:00
committed by FilippoOlivo
parent cf2825241e
commit 10a22fee6f
10 changed files with 676 additions and 433 deletions

View File

@@ -1,19 +1,26 @@
"""Module for Spline model"""
"""Module for the Spline model class"""
import torch
from ..utils import check_consistency
class Spline(torch.nn.Module):
"""TODO: Docstring for Spline."""
"""
Spline model class.
"""
def __init__(self, order=4, knots=None, control_points=None) -> None:
"""
Spline model.
Initialization of the :class:`Spline` class.
:param int order: the order of the spline.
:param torch.Tensor knots: the knot vector.
:param torch.Tensor control_points: the control points.
:param int order: The order of the spline. Default is ``4``.
:param torch.Tensor knots: The tensor representing knots. If ``None``,
the knots will be initialized automatically. Default is ``None``.
:param torch.Tensor control_points: The control points. Default is
``None``.
:raises ValueError: If the order is negative.
:raises ValueError: If both knots and control points are ``None``.
:raises ValueError: If the knot tensor is not one-dimensional.
"""
super().__init__()
@@ -63,13 +70,13 @@ class Spline(torch.nn.Module):
def basis(self, x, k, i, t):
"""
Recursive function to compute the basis functions of the spline.
Recursive method to compute the basis functions of the spline.
:param torch.Tensor x: points to be evaluated.
:param int k: spline degree
:param int i: the index of the interval
:param torch.Tensor t: vector of knots
:return: the basis functions evaluated at x
:param torch.Tensor x: The points to be evaluated.
:param int k: The spline degree.
:param int i: The index of the interval.
:param torch.Tensor t: The tensor of knots.
:return: The basis functions evaluated at x
:rtype: torch.Tensor
"""
@@ -100,11 +107,23 @@ class Spline(torch.nn.Module):
@property
def control_points(self):
"""TODO: Docstring for control_points."""
"""
The control points of the spline.
:return: The control points.
:rtype: torch.Tensor
"""
return self._control_points
@control_points.setter
def control_points(self, value):
"""
Set the control points of the spline.
:param value: The control points.
:type value: torch.Tensor | dict
:raises ValueError: If invalid value is passed.
"""
if isinstance(value, dict):
if "n" not in value:
raise ValueError("Invalid value for control_points")
@@ -118,11 +137,23 @@ class Spline(torch.nn.Module):
@property
def knots(self):
"""TODO: Docstring for knots."""
"""
The knots of the spline.
:return: The knots.
:rtype: torch.Tensor
"""
return self._knots
@knots.setter
def knots(self, value):
"""
Set the knots of the spline.
:param value: The knots.
:type value: torch.Tensor | dict
:raises ValueError: If invalid value is passed.
"""
if isinstance(value, dict):
type_ = value.get("type", "auto")
@@ -152,10 +183,10 @@ class Spline(torch.nn.Module):
def forward(self, x):
"""
Forward pass of the spline model.
Forward pass for the :class:`Spline` model.
:param torch.Tensor x: points to be evaluated.
:return: the spline evaluated at x
:param torch.Tensor x: The input tensor.
:return: The output tensor.
:rtype: torch.Tensor
"""
t = self.knots