Unifying integral kernel NO architectures (#239)

* Unify integral kernel NO architectures with NeuralKernelOperator
* Implement FNO based on NeuralKernelOperator
* modify doc for FNO and add for FourierIntegralKernel, NeuralKernelOperator
* adding tests

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2024-02-21 11:15:40 +01:00
committed by GitHub
parent eb1af0b50e
commit e516e779f9
7 changed files with 354 additions and 80 deletions

View File

@@ -48,11 +48,13 @@ Models
:maxdepth: 5
Network <models/network.rst>
KernelNeuralOperator <models/base_no.rst>
FeedForward <models/fnn.rst>
MultiFeedForward <models/multifeedforward.rst>
ResidualFeedForward <models/fnn_residual.rst>
DeepONet <models/deeponet.rst>
MIONet <models/mionet.rst>
FourierIntegralKernel <models/fourier_kernel.rst>
FNO <models/fno.rst>
Layers

View File

@@ -0,0 +1,7 @@
KernelNeuralOperator
=======================
.. currentmodule:: pina.model.base_no
.. autoclass:: KernelNeuralOperator
:members:
:show-inheritance:

View File

@@ -0,0 +1,7 @@
FourierIntegralKernel
=========================
.. currentmodule:: pina.model.fno
.. autoclass:: FourierIntegralKernel
:members:
:show-inheritance:

View File

@@ -1,13 +1,16 @@
__all__ = [
"FeedForward",
"ResidualFeedForward",
"MultiFeedForward",
"DeepONet",
"MIONet",
"FNO",
'FeedForward',
'ResidualFeedForward',
'MultiFeedForward',
'DeepONet',
'MIONet',
'FNO',
'FourierIntegralKernel',
'KernelNeuralOperator'
]
from .feed_forward import FeedForward, ResidualFeedForward
from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet, MIONet
from .fno import FNO
from .fno import FNO, FourierIntegralKernel
from .base_no import KernelNeuralOperator

136
pina/model/base_no.py Normal file
View File

@@ -0,0 +1,136 @@
"""
Kernel Neural Operator Module.
"""
import torch
from pina.utils import check_consistency
class KernelNeuralOperator(torch.nn.Module):
r"""
Base class for composing Neural Operators with integral kernels.
This is a base class for composing neural operators with multiple
integral kernels. All neural operator models defined in PINA inherit
from this class. The structure is inspired by the work of Kovachki, N.
et al. see Figure 2 of the reference for extra details. The Neural
Operators inheriting from this class can be written as:
.. math::
G_\theta := P \circ K_m \circ \cdot \circ K_1 \circ L
where:
* :math:`G_\theta: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow
\mathcal{D}\subset \mathbb{R}^{\rm{out}}` is the neural operator
approximation of the unknown real operator :math:`G`, that is
:math:`G \approx G_\theta`
* :math:`L: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow
\mathbb{R}^{\rm{emb}}` is a lifting operator mapping the input
from its domain :math:`\mathcal{A}\subset \mathbb{R}^{\rm{in}}`
to its embedding dimension :math:`\mathbb{R}^{\rm{emb}}`
* :math:`\{K_i : \mathbb{R}^{\rm{emb}} \rightarrow
\mathbb{R}^{\rm{emb}} \}_{i=1}^m` are :math:`m` integral kernels
mapping each hidden representation to the next one.
* :math:`P : \mathbb{R}^{\rm{emb}} \rightarrow \mathcal{D}\subset
\mathbb{R}^{\rm{out}}` is a projection operator mapping the hidden
representation to the output function.
.. seealso::
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2023). *Neural operator: Learning maps between function
spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97.
"""
def __init__(self, lifting_operator, integral_kernels, projection_operator):
"""
:param torch.nn.Module lifting_operator: The lifting operator
mapping the input to its hidden dimension.
:param torch.nn.Module integral_kernels: List of integral kernels
mapping each hidden representation to the next one.
:param torch.nn.Module projection_operator: The projection operator
mapping the hidden representation to the output function.
"""
super().__init__()
self._lifting_operator = lifting_operator
self._integral_kernels = integral_kernels
self._projection_operator = projection_operator
@property
def lifting_operator(self):
"""
The lifting operator property.
"""
return self._lifting_operator
@lifting_operator.setter
def lifting_operator(self, value):
"""
The lifting operator setter
:param torch.nn.Module value: The lifting operator torch module.
"""
check_consistency(value, torch.nn.Module)
self._lifting_operator = value
@property
def projection_operator(self):
"""
The projection operator property.
"""
return self._projection_operator
@projection_operator.setter
def projection_operator(self, value):
"""
The projection operator setter
:param torch.nn.Module value: The projection operator torch module.
"""
check_consistency(value, torch.nn.Module)
self._projection_operator = value
@property
def integral_kernels(self):
"""
The integral kernels operator property.
"""
return self._integral_kernels
@integral_kernels.setter
def integral_kernels(self, value):
"""
The integral kernels operator setter
:param torch.nn.Module value: The integral kernels operator torch
module.
"""
check_consistency(value, torch.nn.Module)
self._integral_kernels = value
def forward(self, x):
r"""
Forward computation for Base Neural Operator. It performs a
lifting of the input by the ``lifting_operator``.
Then different layers integral kernels are applied using
``integral_kernels``. Finally the output is projected
to the final dimensionality by the ``projection_operator``.
:param torch.Tensor x: The input tensor for performing the
computation. It expects a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem. In particular
:math:`D` is the number of spatial/paramtric/temporal variables
plus the field variables. For example for 2D problems with 2
output\ variables :math:`D=4`.
:return: The output tensor obtained from the NO.
:rtype: torch.Tensor
"""
x = self.lifting_operator(x)
x = self.integral_kernels(x)
x = self.projection_operator(x)
return x

View File

@@ -1,54 +1,66 @@
"""
Fourier Neural Operator Module.
"""
import torch
import torch.nn as nn
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from pina import LabelTensor
import warnings
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .base_no import KernelNeuralOperator
class FNO(torch.nn.Module):
class FourierIntegralKernel(torch.nn.Module):
"""
The PINA implementation of Fourier Neural Operator network.
Implementation of Fourier Integral Kernel network.
Fourier Neural Operator (FNO) is a general architecture for learning Operators.
Unlike traditional machine learning methods FNO is designed to map
entire functions to other functions. It can be trained both with
Supervised learning strategies. FNO does global convolution by performing the
operation on the Fourier space.
This class implements the Fourier Integral Kernel network, which is a
PINA implementation of Fourier Neural Operator kernel network.
It performs global convolution by operating in the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). *Fourier neural operator for
parametric partial differential equations*.
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2020). *Fourier neural operator for parametric partial
differential equations*.
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
def __init__(
self,
lifting_net,
projecting_net,
n_modes,
dimensions=3,
padding=8,
padding_type="constant",
inner_size=20,
n_layers=2,
func=nn.Tanh,
layers=None,
):
def __init__(self,
input_numb_fields,
output_numb_fields,
n_modes,
dimensions=3,
padding=8,
padding_type="constant",
inner_size=20,
n_layers=2,
func=nn.Tanh,
layers=None):
"""
:param int input_numb_fields: Number of input fields.
:param int output_numb_fields: Number of output fields.
:param int | list[int] n_modes: Number of modes.
:param int dimensions: Number of dimensions (1, 2, or 3).
:param int padding: Padding size, defaults to 8.
:param str padding_type: Type of padding, defaults to "constant".
:param int inner_size: Inner size, defaults to 20.
:param int n_layers: Number of layers, defaults to 2.
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
:param list[int] layers: List of layer sizes, defaults to None.
"""
super().__init__()
# check type consistency
check_consistency(lifting_net, nn.Module)
check_consistency(projecting_net, nn.Module)
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
if layers is not None:
if isinstance(layers, (tuple, list)):
check_consistency(layers, int)
@@ -57,13 +69,9 @@ class FNO(torch.nn.Module):
if not isinstance(n_modes, (list, tuple, int)):
raise ValueError(
"n_modes must be a int or list or tuple of valid modes."
" More information on the official documentation."
)
" More information on the official documentation.")
# assign variables
# TODO check input lifting net and input projecting net
self._lifting_net = lifting_net
self._projecting_net = projecting_net
# assign padding
self._padding = padding
# initialize fourier layer for each dimension
@@ -74,9 +82,11 @@ class FNO(torch.nn.Module):
elif dimensions == 3:
fourier_layer = FourierBlock3D
else:
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
raise NotImplementedError(
"FNO implemented only for 1D/2D/3D data."
)
# Here we build the FNO by stacking Fourier Blocks
# Here we build the FNO kernels by stacking Fourier Blocks
# 1. Assign output dimensions for each FNO layer
if layers is None:
@@ -86,43 +96,33 @@ class FNO(torch.nn.Module):
if isinstance(func, list):
if len(layers) != len(func):
raise RuntimeError(
"Uncosistent number of layers and functions."
)
self._functions = func
'Uncosistent number of layers and functions.')
_functions = func
else:
self._functions = [func for _ in range(len(layers))]
_functions = [func for _ in range(len(layers) - 1)]
_functions.append(torch.nn.Identity)
# 3. Assign modes functions for each FNO layer
if isinstance(n_modes, list):
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
n_modes
):
if all(isinstance(i, list)
for i in n_modes) and len(layers) != len(n_modes):
raise RuntimeError(
"Uncosistent number of layers and functions."
)
"Uncosistent number of layers and functions.")
elif all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers)
else:
n_modes = [n_modes] * len(layers)
# 4. Build the FNO network
tmp_layers = layers.copy()
first_parameter = next(lifting_net.parameters())
input_shape = first_parameter.size()
out_feats = lifting_net(torch.rand(size=input_shape)).shape[-1]
tmp_layers.insert(0, out_feats)
self._layers = []
for i in range(len(tmp_layers) - 1):
self._layers.append(
fourier_layer(
input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i],
activation=self._functions[i],
)
)
self._layers = nn.Sequential(*self._layers)
_layers = []
tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
for i in range(len(layers)):
_layers.append(
fourier_layer(input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i],
activation=_functions[i]))
self._layers = nn.Sequential(*_layers)
# 5. Padding values for spectral conv
if isinstance(padding, int):
@@ -140,23 +140,22 @@ class FNO(torch.nn.Module):
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
:param torch.Tensor x: The input tensor for fourier block, depending on
``dimension`` in the initialization. In particular it is expected
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization.
In particular it is expected:
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from the FNO.
:return: The output tensor obtained from the kernels convolution.
:rtype: torch.Tensor
"""
if isinstance(x, LabelTensor): # TODO remove when Network is fixed
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
warnings.warn(
"LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor"
'LabelTensor passed as input is not allowed,'
' casting LabelTensor to Torch.Tensor'
)
x = x.as_subclass(torch.Tensor)
# lifting the input in higher dimensional space
x = self._lifting_net(x)
# permuting the input [batch, channels, x, y, ...]
permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
x = x.permute(permutation_idx)
@@ -175,5 +174,85 @@ class FNO(torch.nn.Module):
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
x = x.permute(permutation_idx)
# apply projecting operator and return
return self._projecting_net(x)
return x
class FNO(KernelNeuralOperator):
"""
The PINA implementation of Fourier Neural Operator network.
Fourier Neural Operator (FNO) is a general architecture for
learning Operators. Unlike traditional machine learning methods
FNO is designed to map entire functions to other functions. It
can be trained with Supervised learning strategies. FNO does global
convolution by performing the operation on the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2020). *Fourier neural operator for parametric partial
differential equations*.
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
def __init__(self,
lifting_net,
projecting_net,
n_modes,
dimensions=3,
padding=8,
padding_type="constant",
inner_size=20,
n_layers=2,
func=nn.Tanh,
layers=None):
"""
:param torch.nn.Module lifting_net: The neural network for lifting
the input.
:param torch.nn.Module projecting_net: The neural network for
projecting the output.
:param int | list[int] n_modes: Number of modes.
:param int dimensions: Number of dimensions (1, 2, or 3).
:param int padding: Padding size, defaults to 8.
:param str padding_type: Type of padding, defaults to `constant`.
:param int inner_size: Inner size, defaults to 20.
:param int n_layers: Number of layers, defaults to 2.
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
:param list[int] layers: List of layer sizes, defaults to None.
"""
lifting_operator_out = lifting_net(
torch.rand(size=next(lifting_net.parameters()).size())).shape[-1]
super().__init__(lifting_operator=lifting_net,
projection_operator=projecting_net,
integral_kernels=FourierIntegralKernel(
input_numb_fields=lifting_operator_out,
output_numb_fields=next(
projecting_net.parameters()).size(),
n_modes=n_modes,
dimensions=dimensions,
padding=padding,
padding_type=padding_type,
inner_size=inner_size,
n_layers=n_layers,
func=func,
layers=layers))
def forward(self, x):
"""
Forward computation for Fourier Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization. In
particular it is expected:
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from FNO.
:rtype: torch.Tensor
"""
return super().forward(x)

View File

@@ -0,0 +1,40 @@
import torch
from pina.model import KernelNeuralOperator, FeedForward
input_dim = 2
output_dim = 4
embedding_dim = 24
batch_size = 10
numb = 256
data = torch.rand(size=(batch_size, numb, input_dim), requires_grad=True)
output_shape = torch.Size([batch_size, numb, output_dim])
lifting_operator = FeedForward(input_dimensions=input_dim, output_dimensions=embedding_dim)
projection_operator = FeedForward(input_dimensions=embedding_dim, output_dimensions=output_dim)
integral_kernels = torch.nn.Sequential(FeedForward(input_dimensions=embedding_dim,
output_dimensions=embedding_dim),
FeedForward(input_dimensions=embedding_dim,
output_dimensions=embedding_dim),)
def test_constructor():
KernelNeuralOperator(lifting_operator=lifting_operator,
integral_kernels=integral_kernels,
projection_operator=projection_operator)
def test_forward():
operator = KernelNeuralOperator(lifting_operator=lifting_operator,
integral_kernels=integral_kernels,
projection_operator=projection_operator)
out = operator(data)
assert out.shape == output_shape
def test_backward():
operator = KernelNeuralOperator(lifting_operator=lifting_operator,
integral_kernels=integral_kernels,
projection_operator=projection_operator)
out = operator(data)
loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))
loss.backward()
grad = data.grad
assert grad.shape == data.shape