Add Averaging Neural Operator with tests and a tutorial (#230)
* add Averaging Neural Operator with tests * add backward test * minor changes * doc addition --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
@@ -56,6 +56,7 @@ Models
|
||||
MIONet <models/mionet.rst>
|
||||
FourierIntegralKernel <models/fourier_kernel.rst>
|
||||
FNO <models/fno.rst>
|
||||
AveragingNeuralOperator <models/avno.rst>
|
||||
|
||||
Layers
|
||||
-------------
|
||||
@@ -67,11 +68,11 @@ Layers
|
||||
EnhancedLinear layer <layers/enhanced_linear.rst>
|
||||
Spectral convolution <layers/spectral.rst>
|
||||
Fourier layers <layers/fourier.rst>
|
||||
Averaging layer <layers/avno_layer.rst>
|
||||
Continuous convolution <layers/convolution.rst>
|
||||
Proper Orthogonal Decomposition <layers/pod.rst>
|
||||
Periodic Boundary Condition embeddings <layers/embedding.rst>
|
||||
|
||||
|
||||
Equations and Operators
|
||||
-------------------------
|
||||
|
||||
|
||||
8
docs/source/_rst/layers/avno_layer.rst
Normal file
8
docs/source/_rst/layers/avno_layer.rst
Normal file
@@ -0,0 +1,8 @@
|
||||
Averaging layers
|
||||
====================
|
||||
.. currentmodule:: pina.model.layers.avno_layer
|
||||
|
||||
.. autoclass:: AVNOBlock
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:noindex:
|
||||
7
docs/source/_rst/models/avno.rst
Normal file
7
docs/source/_rst/models/avno.rst
Normal file
@@ -0,0 +1,7 @@
|
||||
Averaging Neural Operator
|
||||
==============================
|
||||
.. currentmodule:: pina.model.avno
|
||||
|
||||
.. autoclass:: AveragingNeuralOperator
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -7,6 +7,7 @@ __all__ = [
|
||||
"FNO",
|
||||
"FourierIntegralKernel",
|
||||
"KernelNeuralOperator",
|
||||
"AveragingNeuralOperator",
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward, ResidualFeedForward
|
||||
@@ -14,3 +15,4 @@ from .multi_feed_forward import MultiFeedForward
|
||||
from .deeponet import DeepONet, MIONet
|
||||
from .fno import FNO, FourierIntegralKernel
|
||||
from .base_no import KernelNeuralOperator
|
||||
from .avno import AveragingNeuralOperator
|
||||
|
||||
104
pina/model/avno.py
Normal file
104
pina/model/avno.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Module Averaging Neural Operator."""
|
||||
|
||||
from torch import nn, concatenate
|
||||
from . import FeedForward
|
||||
from .layers import AVNOBlock
|
||||
from .base_no import KernelNeuralOperator
|
||||
from pina.utils import check_consistency
|
||||
|
||||
|
||||
class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
Implementation of Averaging Neural Operator.
|
||||
|
||||
Averaging Neural Operator is a general architecture for
|
||||
learning Operators. Unlike traditional machine learning methods
|
||||
AveragingNeuralOperator is designed to map entire functions
|
||||
to other functions. It can be trained with Supervised learning strategies.
|
||||
AveragingNeuralOperator does convolution by performing a field average.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lanthaler S. Li, Z., Kovachki,
|
||||
Stuart, A. (2020). *The Nonlocal Neural Operator:
|
||||
Universal Approximation*.
|
||||
DOI: `arXiv preprint arXiv:2304.13221.
|
||||
<https://arxiv.org/abs/2304.13221>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
field_indices,
|
||||
coordinates_indices,
|
||||
dimension=3,
|
||||
inner_size=100,
|
||||
n_layers=4,
|
||||
func=nn.GELU,
|
||||
):
|
||||
"""
|
||||
:param int input_numb_fields: The number of input components
|
||||
of the model.
|
||||
:param int output_numb_fields: The number of output components
|
||||
of the model.
|
||||
:param int dimension: the dimension of the domain of the functions.
|
||||
:param int inner_size: number of neurons in the hidden layer(s).
|
||||
Defaults to 100.
|
||||
:param int n_layers: number of hidden layers. Default is 4.
|
||||
:param func: the activation function to use. Default to nn.GELU.
|
||||
:param list[str] field_indices: the label of the fields
|
||||
in the input tensor.
|
||||
:param list[str] coordinates_indices: the label of the
|
||||
coordinates in the input tensor.
|
||||
"""
|
||||
|
||||
# check consistency
|
||||
check_consistency(input_numb_fields, int)
|
||||
check_consistency(output_numb_fields, int)
|
||||
check_consistency(field_indices, str)
|
||||
check_consistency(coordinates_indices, str)
|
||||
check_consistency(dimension, int)
|
||||
check_consistency(inner_size, int)
|
||||
check_consistency(n_layers, int)
|
||||
check_consistency(func, nn.Module, subclass=True)
|
||||
|
||||
# assign
|
||||
self.input_numb_fields = input_numb_fields
|
||||
self.output_numb_fields = output_numb_fields
|
||||
self.dimension = dimension
|
||||
self.coordinates_indices = coordinates_indices
|
||||
self.field_indices = field_indices
|
||||
integral_net = nn.Sequential(
|
||||
*[AVNOBlock(inner_size, func) for _ in range(n_layers)])
|
||||
lifting_net = FeedForward(dimension + input_numb_fields, inner_size,
|
||||
inner_size, n_layers, func)
|
||||
projection_net = FeedForward(inner_size + dimension, output_numb_fields,
|
||||
inner_size, n_layers, func)
|
||||
super().__init__(lifting_net, integral_net, projection_net)
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward computation for Averaging Neural Operator. It performs a
|
||||
lifting of the input by the ``lifting_net``. Then different layers
|
||||
of Averaging Neural Operator Blocks are applied.
|
||||
Finally the output is projected to the final dimensionality
|
||||
by the ``projecting_net``.
|
||||
|
||||
:param torch.Tensor x: The input tensor for fourier block,
|
||||
depending on ``dimension`` in the initialization. It expects
|
||||
a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem, i.e. the sum
|
||||
of ``len(coordinates_indices)+len(field_indices)``.
|
||||
:return: The output tensor obtained from Average Neural Operator.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
points_tmp = x.extract(self.coordinates_indices)
|
||||
features_tmp = x.extract(self.field_indices)
|
||||
new_batch = concatenate((features_tmp, points_tmp), dim=2)
|
||||
new_batch = self._lifting_operator(new_batch)
|
||||
new_batch = self._integral_kernels(new_batch)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=2)
|
||||
new_batch = self._projection_operator(new_batch)
|
||||
return new_batch
|
||||
@@ -10,6 +10,7 @@ __all__ = [
|
||||
"FourierBlock3D",
|
||||
"PODBlock",
|
||||
"PeriodicBoundaryEmbedding",
|
||||
"AVNOBlock",
|
||||
]
|
||||
|
||||
from .convolution_2d import ContinuousConvBlock
|
||||
@@ -22,3 +23,4 @@ from .spectral import (
|
||||
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from .pod import PODBlock
|
||||
from .embedding import PeriodicBoundaryEmbedding
|
||||
from .avno_layer import AVNOBlock
|
||||
|
||||
67
pina/model/layers/avno_layer.py
Normal file
67
pina/model/layers/avno_layer.py
Normal file
@@ -0,0 +1,67 @@
|
||||
""" Module for Averaging Neural Operator Layer class. """
|
||||
|
||||
from torch import nn, mean
|
||||
from pina.utils import check_consistency
|
||||
|
||||
|
||||
class AVNOBlock(nn.Module):
|
||||
r"""
|
||||
The PINA implementation of the inner layer of the Averaging Neural Operator.
|
||||
|
||||
The operator layer performs an affine transformation where the convolution
|
||||
is approximated with a local average. Given the input function
|
||||
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
|
||||
the operator update :math:`K(v)` as:
|
||||
|
||||
.. math::
|
||||
K(v) = \sigma\left(Wv(x) + b + \frac{1}{|\mathcal{A}|}\int v(y)dy\right)
|
||||
|
||||
where:
|
||||
|
||||
* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
|
||||
corresponding to the ``hidden_size`` object
|
||||
* :math:`\sigma` is a non-linear activation, corresponding to the
|
||||
``func`` object
|
||||
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
|
||||
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lanthaler S. Li, Z., Kovachki,
|
||||
Stuart, A. (2020). *The Nonlocal Neural Operator: Universal
|
||||
Approximation*.
|
||||
DOI: `arXiv preprint arXiv:2304.13221.
|
||||
<https://arxiv.org/abs/2304.13221>`_
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size=100, func=nn.GELU):
|
||||
"""
|
||||
:param int hidden_size: Size of the hidden layer, defaults to 100.
|
||||
:param func: The activation function, default to nn.GELU.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Check type consistency
|
||||
check_consistency(hidden_size, int)
|
||||
check_consistency(func, nn.Module, subclass=True)
|
||||
# Assignment
|
||||
self._nn = nn.Linear(hidden_size, hidden_size)
|
||||
self._func = func()
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward pass of the layer, it performs a sum of local average
|
||||
and an affine transformation of the field.
|
||||
|
||||
:param torch.Tensor x: The input tensor for performing the
|
||||
computation. It expects a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem. In particular
|
||||
:math:`D` is the codomain of the function :math:`v`. For example
|
||||
a scalar function has :math:`D=1`, a 4-dimensional vector function
|
||||
:math:`D=4`.
|
||||
:return: The output tensor obtained from Average Neural Operator Block.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._func(self._nn(x) + mean(x, dim=1, keepdim=True))
|
||||
62
tests/test_model/test_avno.py
Normal file
62
tests/test_model/test_avno.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from pina.model import AveragingNeuralOperator
|
||||
from pina import LabelTensor
|
||||
|
||||
output_numb_fields = 5
|
||||
batch_size = 15
|
||||
|
||||
|
||||
def test_constructor():
|
||||
input_numb_fields = 1
|
||||
output_numb_fields = 1
|
||||
#minimuum constructor
|
||||
AveragingNeuralOperator(input_numb_fields,
|
||||
output_numb_fields,
|
||||
coordinates_indices=['p'],
|
||||
field_indices=['v'])
|
||||
|
||||
#all constructor
|
||||
AveragingNeuralOperator(input_numb_fields,
|
||||
output_numb_fields,
|
||||
inner_size=5,
|
||||
n_layers=5,
|
||||
func=torch.nn.ReLU,
|
||||
coordinates_indices=['p'],
|
||||
field_indices=['v'])
|
||||
|
||||
|
||||
def test_forward():
|
||||
input_numb_fields = 1
|
||||
output_numb_fields = 1
|
||||
dimension = 1
|
||||
input_ = LabelTensor(
|
||||
torch.rand(batch_size, 1000, input_numb_fields + dimension), ['p', 'v'])
|
||||
ano = AveragingNeuralOperator(input_numb_fields,
|
||||
output_numb_fields,
|
||||
dimension=dimension,
|
||||
coordinates_indices=['p'],
|
||||
field_indices=['v'])
|
||||
out = ano(input_)
|
||||
assert out.shape == torch.Size(
|
||||
[batch_size, input_.shape[1], output_numb_fields])
|
||||
|
||||
|
||||
def test_backward():
|
||||
input_numb_fields = 1
|
||||
dimension = 1
|
||||
output_numb_fields = 1
|
||||
input_ = LabelTensor(
|
||||
torch.rand(batch_size, 1000, dimension + input_numb_fields),
|
||||
['p', 'v'])
|
||||
input_ = input_.requires_grad_()
|
||||
avno = AveragingNeuralOperator(input_numb_fields,
|
||||
output_numb_fields,
|
||||
dimension=dimension,
|
||||
coordinates_indices=['p'],
|
||||
field_indices=['v'])
|
||||
out = avno(input_)
|
||||
tmp = torch.linalg.norm(out)
|
||||
tmp.backward()
|
||||
grad = input_.grad
|
||||
assert grad.shape == torch.Size(
|
||||
[batch_size, input_.shape[1], dimension + input_numb_fields])
|
||||
Reference in New Issue
Block a user