fix doc model part 1
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Kernel Neural Operator Module.
|
||||
Module for the Kernel Neural Operator model class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -8,13 +8,14 @@ from ..utils import check_consistency
|
||||
|
||||
class KernelNeuralOperator(torch.nn.Module):
|
||||
r"""
|
||||
Base class for composing Neural Operators with integral kernels.
|
||||
Base class for Neural Operators with integral kernels.
|
||||
|
||||
This is a base class for composing neural operator with multiple
|
||||
integral kernels. All neural operator models defined in PINA inherit
|
||||
from this class. The structure is inspired by the work of Kovachki, N.
|
||||
et al. see Figure 2 of the reference for extra details. The Neural
|
||||
Operators inheriting from this class can be written as:
|
||||
This class serves as a foundation for building Neural Operators that
|
||||
incorporate multiple integral kernels. All Neural Operator models in
|
||||
PINA inherit from this class. The design follows the framework proposed
|
||||
by Kovachki et al., as illustrated in Figure 2 of their work.
|
||||
|
||||
Neural Operators derived from this class can be expressed as:
|
||||
|
||||
.. math::
|
||||
G_\theta := P \circ K_m \circ \cdot \circ K_1 \circ L
|
||||
@@ -40,15 +41,18 @@ class KernelNeuralOperator(torch.nn.Module):
|
||||
|
||||
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
|
||||
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||
(2023). *Neural operator: Learning maps between function
|
||||
spaces with applications to PDEs*. Journal of Machine Learning
|
||||
Research, 24(89), 1-97.
|
||||
(2023).
|
||||
*Neural operator: Learning maps between function spaces with
|
||||
applications to PDEs*.
|
||||
Journal of Machine Learning Research, 24(89), 1-97.
|
||||
"""
|
||||
|
||||
def __init__(self, lifting_operator, integral_kernels, projection_operator):
|
||||
"""
|
||||
:param torch.nn.Module lifting_operator: The lifting operator
|
||||
mapping the input to its hidden dimension.
|
||||
Initialization of the :class:`KernelNeuralOperator` class.
|
||||
|
||||
:param torch.nn.Module lifting_operator: The lifting operator mapping
|
||||
the input to its hidden dimension.
|
||||
:param torch.nn.Module integral_kernels: List of integral kernels
|
||||
mapping each hidden representation to the next one.
|
||||
:param torch.nn.Module projection_operator: The projection operator
|
||||
@@ -64,16 +68,19 @@ class KernelNeuralOperator(torch.nn.Module):
|
||||
@property
|
||||
def lifting_operator(self):
|
||||
"""
|
||||
The lifting operator property.
|
||||
The lifting operator module.
|
||||
|
||||
:return: The lifting operator module.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._lifting_operator
|
||||
|
||||
@lifting_operator.setter
|
||||
def lifting_operator(self, value):
|
||||
"""
|
||||
The lifting operator setter
|
||||
Set the lifting operator module.
|
||||
|
||||
:param torch.nn.Module value: The lifting operator torch module.
|
||||
:param torch.nn.Module value: The lifting operator module.
|
||||
"""
|
||||
check_consistency(value, torch.nn.Module)
|
||||
self._lifting_operator = value
|
||||
@@ -81,16 +88,19 @@ class KernelNeuralOperator(torch.nn.Module):
|
||||
@property
|
||||
def projection_operator(self):
|
||||
"""
|
||||
The projection operator property.
|
||||
The projection operator module.
|
||||
|
||||
:return: The projection operator module.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._projection_operator
|
||||
|
||||
@projection_operator.setter
|
||||
def projection_operator(self, value):
|
||||
"""
|
||||
The projection operator setter
|
||||
Set the projection operator module.
|
||||
|
||||
:param torch.nn.Module value: The projection operator torch module.
|
||||
:param torch.nn.Module value: The projection operator module.
|
||||
"""
|
||||
check_consistency(value, torch.nn.Module)
|
||||
self._projection_operator = value
|
||||
@@ -98,37 +108,41 @@ class KernelNeuralOperator(torch.nn.Module):
|
||||
@property
|
||||
def integral_kernels(self):
|
||||
"""
|
||||
The integral kernels operator property.
|
||||
The integral kernels operator module.
|
||||
|
||||
:return: The integral kernels operator module.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._integral_kernels
|
||||
|
||||
@integral_kernels.setter
|
||||
def integral_kernels(self, value):
|
||||
"""
|
||||
The integral kernels operator setter
|
||||
Set the integral kernels operator module.
|
||||
|
||||
:param torch.nn.Module value: The integral kernels operator torch
|
||||
module.
|
||||
:param torch.nn.Module value: The integral kernels operator module.
|
||||
"""
|
||||
check_consistency(value, torch.nn.Module)
|
||||
self._integral_kernels = value
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward computation for Base Neural Operator. It performs a
|
||||
lifting of the input by the ``lifting_operator``.
|
||||
Then different layers integral kernels are applied using
|
||||
``integral_kernels``. Finally the output is projected
|
||||
to the final dimensionality by the ``projection_operator``.
|
||||
Forward pass for the :class:`KernelNeuralOperator` model.
|
||||
|
||||
:param torch.Tensor x: The input tensor for performing the
|
||||
computation. It expects a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem. In particular
|
||||
:math:`D` is the number of spatial/paramtric/temporal variables
|
||||
plus the field variables. For example for 2D problems with 2
|
||||
output\ variables :math:`D=4`.
|
||||
:return: The output tensor obtained from the NO.
|
||||
The ``lifting_operator`` maps the input to the hidden dimension.
|
||||
The ``integral_kernels`` apply the integral kernels to the hidden
|
||||
representation. The ``projection_operator`` maps the hidden
|
||||
representation to the output function.
|
||||
|
||||
:param x: The input tensor for performing the computation. It expects
|
||||
a tensor :math:`B \times N \times D`, where :math:`B` is the
|
||||
batch_size, :math:`N` the number of points in the mesh, and
|
||||
:math:`D` the dimension of the problem. In particular, :math:`D`
|
||||
is the number of spatial, parametric, and/or temporal variables
|
||||
plus the field variables. For instance, for 2D problems with 2
|
||||
output variables, :math:`D=4`.
|
||||
:type x: torch.Tensor | LabelTensor
|
||||
:return: The output tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
x = self.lifting_operator(x)
|
||||
|
||||
Reference in New Issue
Block a user