fix doc model part 2
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for Averaging Neural Operator Layer class."""
|
||||
"""Module for the Low Rank Neural Operator Block class."""
|
||||
|
||||
import torch
|
||||
|
||||
@@ -6,30 +6,8 @@ from ...utils import check_consistency
|
||||
|
||||
|
||||
class LowRankBlock(torch.nn.Module):
|
||||
r"""
|
||||
The PINA implementation of the inner layer of the Averaging Neural Operator.
|
||||
|
||||
The operator layer performs an affine transformation where the convolution
|
||||
is approximated with a local average. Given the input function
|
||||
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
|
||||
the operator update :math:`K(v)` as:
|
||||
|
||||
.. math::
|
||||
K(v) = \sigma\left(Wv(x) + b + \sum_{i=1}^r \langle
|
||||
\psi^{(i)} , v(x) \rangle \phi^{(i)} \right)
|
||||
|
||||
where:
|
||||
|
||||
* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
|
||||
corresponding to the ``hidden_size`` object
|
||||
* :math:`\sigma` is a non-linear activation, corresponding to the
|
||||
``func`` object
|
||||
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
|
||||
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
|
||||
* :math:`\psi^{(i)}\in\mathbb{R}^{\rm{emb}}` and
|
||||
:math:`\phi^{(i)}\in\mathbb{R}^{\rm{emb}}` are :math:`r` a low rank
|
||||
basis functions mapping.
|
||||
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
|
||||
"""
|
||||
The inner block of the Low Rank Neural Operator.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@@ -38,7 +16,6 @@ class LowRankBlock(torch.nn.Module):
|
||||
(2023). *Neural operator: Learning maps between function
|
||||
spaces with applications to PDEs*. Journal of Machine Learning
|
||||
Research, 24(89), 1-97.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -51,30 +28,25 @@ class LowRankBlock(torch.nn.Module):
|
||||
func=torch.nn.Tanh,
|
||||
bias=True,
|
||||
):
|
||||
"""
|
||||
:param int input_dimensions: The number of input components of the
|
||||
model.
|
||||
Expected tensor shape of the form :math:`(*, d)`, where *
|
||||
means any number of dimensions including none,
|
||||
and :math:`d` the ``input_dimensions``.
|
||||
:param int embedding_dimenion: Size of the embedding dimension of the
|
||||
field.
|
||||
:param int rank: The rank number of the basis approximation components
|
||||
of the model. Expected tensor shape of the form :math:`(*, 2d)`,
|
||||
where * means any number of dimensions including none,
|
||||
and :math:`2d` the ``rank`` for both basis functions.
|
||||
:param int inner_size: Number of neurons in the hidden layer(s) for the
|
||||
basis function network. Default is 20.
|
||||
:param int n_layers: Number of hidden layers. for the
|
||||
basis function network. Default is 2.
|
||||
:param func: The activation function to use for the
|
||||
basis function network. If a single
|
||||
:class:`torch.nn.Module` is passed, this is used as
|
||||
activation function after any layers, except the last one.
|
||||
If a list of Modules is passed,
|
||||
they are used as activation functions at any layers, in order.
|
||||
:param bool bias: If ``True`` the MLP will consider some bias for the
|
||||
basis function network.
|
||||
r"""
|
||||
Initialization of the :class:`LowRankBlock` class.
|
||||
|
||||
:param int input_dimensions: The input dimension of the field.
|
||||
:param int embedding_dimenion: The embedding dimension of the field.
|
||||
:param int rank: The rank of the low rank approximation. The expected
|
||||
value is :math:`2d`, where :math:`d` is the rank of each basis
|
||||
function.
|
||||
:param int inner_size: The number of neurons for each hidden layer in
|
||||
the basis function neural network. Default is ``20``.
|
||||
:param int n_layers: The number of hidden layers in the basis function
|
||||
neural network. Default is ``2``.
|
||||
:param func: The activation function. If a list is passed, it must have
|
||||
the same length as ``n_layers``. If a single function is passed, it
|
||||
is used for all layers, except for the last one.
|
||||
Default is :class:`torch.nn.Tanh`.
|
||||
:type func: torch.nn.Module | list[torch.nn.Module]
|
||||
:param bool bias: If ``True`` bias is considered for the basis function
|
||||
neural network. Default is ``True``.
|
||||
"""
|
||||
super().__init__()
|
||||
from ..feed_forward import FeedForward
|
||||
@@ -96,26 +68,16 @@ class LowRankBlock(torch.nn.Module):
|
||||
|
||||
def forward(self, x, coords):
|
||||
r"""
|
||||
Forward pass of the layer, it performs an affine transformation of
|
||||
the field, and a low rank approximation by
|
||||
doing a dot product of the basis
|
||||
:math:`\psi^{(i)}` with the filed vector :math:`v`, and use this
|
||||
coefficients to expand :math:`\phi^{(i)}` evaluated in the
|
||||
spatial input :math:`x`.
|
||||
Forward pass of the block. It performs an affine transformation of the
|
||||
field, followed by a low rank approximation. The latter is performed by
|
||||
means of a dot product of the basis :math:`\psi^{(i)}` with the vector
|
||||
field :math:`v` to compute coefficients used to expand
|
||||
:math:`\phi^{(i)}`, evaluated in the spatial input :math:`x`.
|
||||
|
||||
:param torch.Tensor x: The input tensor for performing the
|
||||
computation. It expects a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem. In particular
|
||||
:math:`D` is the codomain of the function :math:`v`. For example
|
||||
a scalar function has :math:`D=1`, a 4-dimensional vector function
|
||||
:math:`D=4`.
|
||||
:param torch.Tensor coords: The coordinates in which the field is
|
||||
evaluated for performing the computation. It expects a
|
||||
tensor :math:`B \times N \times d`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the domain.
|
||||
:return: The output tensor obtained from Average Neural Operator Block.
|
||||
:param torch.Tensor x: The input tensor for performing the computation.
|
||||
:param torch.Tensor coords: The coordinates for which the field is
|
||||
evaluated to perform the computation.
|
||||
:return: The output tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract basis
|
||||
@@ -138,5 +100,8 @@ class LowRankBlock(torch.nn.Module):
|
||||
def rank(self):
|
||||
"""
|
||||
The basis rank.
|
||||
|
||||
:return: The basis rank.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._rank
|
||||
|
||||
Reference in New Issue
Block a user