Add Averaging Neural Operator with tests and a tutorial (#230)

* add Averaging Neural Operator with tests
* add backward test
* minor changes
* doc addition

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
guglielmopadula
2024-03-05 12:30:53 +01:00
committed by GitHub
parent b10e02103b
commit 43f69242ab
8 changed files with 254 additions and 1 deletion

pina/model/__init__.py

@@ -7,6 +7,7 @@ __all__ = [
"FNO",
"FourierIntegralKernel",
"KernelNeuralOperator",
"AveragingNeuralOperator",
]
from .feed_forward import FeedForward, ResidualFeedForward
@@ -14,3 +15,4 @@ from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet, MIONet
from .fno import FNO, FourierIntegralKernel
from .base_no import KernelNeuralOperator
from .avno import AveragingNeuralOperator

pina/model/avno.py (new file, 104 lines)

@@ -0,0 +1,104 @@
"""Module Averaging Neural Operator."""
from torch import nn, concatenate
from . import FeedForward
from .layers import AVNOBlock
from .base_no import KernelNeuralOperator
from pina.utils import check_consistency
class AveragingNeuralOperator(KernelNeuralOperator):
    """
    Implementation of the Averaging Neural Operator.

    The Averaging Neural Operator is a general architecture for learning
    operators. Unlike traditional machine learning methods, which map
    finite-dimensional vectors, AveragingNeuralOperator is designed to map
    entire functions to other functions. It can be trained with supervised
    learning strategies. AveragingNeuralOperator approximates the kernel
    integral of the operator layers with a field average.

    .. seealso::

        **Original reference**: Lanthaler, S., Li, Z., and Stuart, A.
        (2023). *The Nonlocal Neural Operator: Universal Approximation*.
        DOI: `arXiv preprint arXiv:2304.13221
        <https://arxiv.org/abs/2304.13221>`_
    """
    def __init__(
        self,
        input_numb_fields,
        output_numb_fields,
        field_indices,
        coordinates_indices,
        dimension=3,
        inner_size=100,
        n_layers=4,
        func=nn.GELU,
    ):
        """
        :param int input_numb_fields: The number of input components
            of the model.
        :param int output_numb_fields: The number of output components
            of the model.
        :param list[str] field_indices: The labels of the fields
            in the input tensor.
        :param list[str] coordinates_indices: The labels of the
            coordinates in the input tensor.
        :param int dimension: The dimension of the domain of the
            functions. Defaults to 3.
        :param int inner_size: The number of neurons in the hidden
            layer(s). Defaults to 100.
        :param int n_layers: The number of hidden layers. Defaults to 4.
        :param func: The activation function to use. Defaults to
            ``nn.GELU``.
        """

        # check consistency
        check_consistency(input_numb_fields, int)
        check_consistency(output_numb_fields, int)
        check_consistency(field_indices, str)
        check_consistency(coordinates_indices, str)
        check_consistency(dimension, int)
        check_consistency(inner_size, int)
        check_consistency(n_layers, int)
        check_consistency(func, nn.Module, subclass=True)

        # assign
        self.input_numb_fields = input_numb_fields
        self.output_numb_fields = output_numb_fields
        self.dimension = dimension
        self.coordinates_indices = coordinates_indices
        self.field_indices = field_indices
        integral_net = nn.Sequential(
            *[AVNOBlock(inner_size, func) for _ in range(n_layers)])
        lifting_net = FeedForward(dimension + input_numb_fields, inner_size,
                                  inner_size, n_layers, func)
        projection_net = FeedForward(inner_size + dimension, output_numb_fields,
                                     inner_size, n_layers, func)
        super().__init__(lifting_net, integral_net, projection_net)
    def forward(self, x):
        r"""
        Forward computation for Averaging Neural Operator. It performs a
        lifting of the input by the ``lifting_net``. Then several
        Averaging Neural Operator blocks are applied. Finally, the output
        is projected to the final dimensionality by the
        ``projection_net``.

        :param torch.Tensor x: The input tensor for the averaging neural
            operator, depending on ``dimension`` in the initialization.
            It expects a tensor :math:`B \times N \times D`, where
            :math:`B` is the batch size, :math:`N` the number of points
            in the mesh, and :math:`D` the dimension of the problem,
            i.e. ``len(coordinates_indices) + len(field_indices)``.
        :return: The output tensor obtained from the Averaging Neural
            Operator.
        :rtype: torch.Tensor
        """
        points_tmp = x.extract(self.coordinates_indices)
        features_tmp = x.extract(self.field_indices)
        new_batch = concatenate((features_tmp, points_tmp), dim=2)
        new_batch = self._lifting_operator(new_batch)
        new_batch = self._integral_kernels(new_batch)
        new_batch = concatenate((new_batch, points_tmp), dim=2)
        new_batch = self._projection_operator(new_batch)
        return new_batch
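
Below is a minimal usage sketch for the model above. It is not part of the commit: the label names (``x``, ``y``, ``z``, ``u``) and tensor sizes are illustrative, and it assumes that ``pina.LabelTensor`` labels the last tensor dimension, consistent with the ``extract`` calls in ``forward``.

# usage sketch (illustrative, not part of the commit)
import torch
from pina import LabelTensor
from pina.model import AveragingNeuralOperator

batch, points = 8, 100
coords = ["x", "y", "z"]   # hypothetical coordinate labels
fields = ["u"]             # hypothetical field label
# input tensor of shape B x N x D, with D = len(coords) + len(fields)
data = LabelTensor(torch.rand(batch, points, 4), coords + fields)

model = AveragingNeuralOperator(
    input_numb_fields=len(fields),
    output_numb_fields=1,
    field_indices=fields,
    coordinates_indices=coords,
    dimension=len(coords),
)
out = model(data)  # shape: [batch, points, output_numb_fields]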

pina/model/layers/__init__.py

@@ -10,6 +10,7 @@ __all__ = [
"FourierBlock3D",
"PODBlock",
"PeriodicBoundaryEmbedding",
"AVNOBlock",
]
from .convolution_2d import ContinuousConvBlock
@@ -22,3 +23,4 @@ from .spectral import (
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .pod import PODBlock
from .embedding import PeriodicBoundaryEmbedding
from .avno_layer import AVNOBlock

pina/model/layers/avno_layer.py (new file, 67 lines)

@@ -0,0 +1,67 @@
""" Module for Averaging Neural Operator Layer class. """
from torch import nn, mean
from pina.utils import check_consistency
class AVNOBlock(nn.Module):
r"""
The PINA implementation of the inner layer of the Averaging Neural Operator.
The operator layer performs an affine transformation where the convolution
is approximated with a local average. Given the input function
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
the operator update :math:`K(v)` as:
.. math::
K(v) = \sigma\left(Wv(x) + b + \frac{1}{|\mathcal{A}|}\int v(y)dy\right)
where:
* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
corresponding to the ``hidden_size`` object
* :math:`\sigma` is a non-linear activation, corresponding to the
``func`` object
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
.. seealso::
**Original reference**: Lanthaler S. Li, Z., Kovachki,
Stuart, A. (2020). *The Nonlocal Neural Operator: Universal
Approximation*.
DOI: `arXiv preprint arXiv:2304.13221.
<https://arxiv.org/abs/2304.13221>`_
"""
    def __init__(self, hidden_size=100, func=nn.GELU):
        """
        :param int hidden_size: Size of the hidden layer. Defaults to 100.
        :param func: The activation function. Defaults to ``nn.GELU``.
        """
        super().__init__()

        # Check type consistency
        check_consistency(hidden_size, int)
        check_consistency(func, nn.Module, subclass=True)

        # Assignment
        self._nn = nn.Linear(hidden_size, hidden_size)
        self._func = func()
    def forward(self, x):
        r"""
        Forward pass of the layer. It performs the sum of a local average
        and an affine transformation of the field.

        :param torch.Tensor x: The input tensor for performing the
            computation. It expects a tensor :math:`B \times N \times D`,
            where :math:`B` is the batch size, :math:`N` the number of
            points in the mesh, and :math:`D` the dimension of the
            problem. In particular, :math:`D` is the codomain of the
            function :math:`v`. For example, a scalar function has
            :math:`D=1`, while a 4-dimensional vector function has
            :math:`D=4`.
        :return: The output tensor obtained from the Averaging Neural
            Operator Block.
        :rtype: torch.Tensor
        """
        # affine map of the field plus the mesh average, then activation;
        # mean over dim=1 averages across the N mesh points
        return self._func(self._nn(x) + mean(x, dim=1, keepdim=True))
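
A quick shape-check sketch for the block follows. Again, it is illustrative rather than part of the commit; the tensor sizes are arbitrary, and the comments spell out why the output shape matches the input.

# shape-check sketch (illustrative, not part of the commit)
import torch
from pina.model.layers import AVNOBlock

block = AVNOBlock(hidden_size=32)
v = torch.rand(8, 100, 32)  # B=8 samples, N=100 mesh points, emb=32
out = block(v)
assert out.shape == v.shape  # the block preserves B x N x emb
# mean(v, dim=1, keepdim=True) has shape [8, 1, 32] and broadcasts over
# the N dimension, so every point receives the same field average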