fix doc model part 2

This commit is contained in:
giovanni
2025-03-14 16:07:08 +01:00
committed by FilippoOlivo
parent 194f5d24c4
commit 28d24f3f41
18 changed files with 887 additions and 851 deletions

View File

@@ -1,4 +1,4 @@
"""Module for Averaging Neural Operator Layer class."""
"""Module for the Low Rank Neural Operator Block class."""
import torch
@@ -6,30 +6,8 @@ from ...utils import check_consistency
class LowRankBlock(torch.nn.Module):
r"""
The PINA implementation of the inner layer of the Averaging Neural Operator.
The operator layer performs an affine transformation where the convolution
is approximated with a local average. Given the input function
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
the operator update :math:`K(v)` as:
.. math::
K(v) = \sigma\left(Wv(x) + b + \sum_{i=1}^r \langle
\psi^{(i)} , v(x) \rangle \phi^{(i)} \right)
where:
* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
corresponding to the ``hidden_size`` object
* :math:`\sigma` is a non-linear activation, corresponding to the
``func`` object
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
* :math:`\psi^{(i)}\in\mathbb{R}^{\rm{emb}}` and
:math:`\phi^{(i)}\in\mathbb{R}^{\rm{emb}}` are :math:`r` a low rank
basis functions mapping.
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
"""
The inner block of the Low Rank Neural Operator.
.. seealso::
@@ -38,7 +16,6 @@ class LowRankBlock(torch.nn.Module):
(2023). *Neural operator: Learning maps between function
spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97.
"""
def __init__(
@@ -51,30 +28,25 @@ class LowRankBlock(torch.nn.Module):
func=torch.nn.Tanh,
bias=True,
):
"""
:param int input_dimensions: The number of input components of the
model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none,
and :math:`d` the ``input_dimensions``.
:param int embedding_dimenion: Size of the embedding dimension of the
field.
:param int rank: The rank number of the basis approximation components
of the model. Expected tensor shape of the form :math:`(*, 2d)`,
where * means any number of dimensions including none,
and :math:`2d` the ``rank`` for both basis functions.
:param int inner_size: Number of neurons in the hidden layer(s) for the
basis function network. Default is 20.
:param int n_layers: Number of hidden layers. for the
basis function network. Default is 2.
:param func: The activation function to use for the
basis function network. If a single
:class:`torch.nn.Module` is passed, this is used as
activation function after any layers, except the last one.
If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param bool bias: If ``True`` the MLP will consider some bias for the
basis function network.
r"""
Initialization of the :class:`LowRankBlock` class.
:param int input_dimensions: The input dimension of the field.
:param int embedding_dimenion: The embedding dimension of the field.
:param int rank: The rank of the low rank approximation. The expected
value is :math:`2d`, where :math:`d` is the rank of each basis
function.
:param int inner_size: The number of neurons for each hidden layer in
the basis function neural network. Default is ``20``.
:param int n_layers: The number of hidden layers in the basis function
neural network. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:type func: torch.nn.Module | list[torch.nn.Module]
:param bool bias: If ``True`` bias is considered for the basis function
neural network. Default is ``True``.
"""
super().__init__()
from ..feed_forward import FeedForward
@@ -96,26 +68,16 @@ class LowRankBlock(torch.nn.Module):
def forward(self, x, coords):
r"""
Forward pass of the layer, it performs an affine transformation of
the field, and a low rank approximation by
doing a dot product of the basis
:math:`\psi^{(i)}` with the filed vector :math:`v`, and use this
coefficients to expand :math:`\phi^{(i)}` evaluated in the
spatial input :math:`x`.
Forward pass of the block. It performs an affine transformation of the
field, followed by a low rank approximation. The latter is performed by
means of a dot product of the basis :math:`\psi^{(i)}` with the vector
field :math:`v` to compute coefficients used to expand
:math:`\phi^{(i)}`, evaluated in the spatial input :math:`x`.
:param torch.Tensor x: The input tensor for performing the
computation. It expects a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem. In particular
:math:`D` is the codomain of the function :math:`v`. For example
a scalar function has :math:`D=1`, a 4-dimensional vector function
:math:`D=4`.
:param torch.Tensor coords: The coordinates in which the field is
evaluated for performing the computation. It expects a
tensor :math:`B \times N \times d`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the domain.
:return: The output tensor obtained from Average Neural Operator Block.
:param torch.Tensor x: The input tensor for performing the computation.
:param torch.Tensor coords: The coordinates for which the field is
evaluated to perform the computation.
:return: The output tensor.
:rtype: torch.Tensor
"""
# extract basis
@@ -138,5 +100,8 @@ class LowRankBlock(torch.nn.Module):
def rank(self):
"""
The basis rank.
:return: The basis rank.
:rtype: int
"""
return self._rank