Neural Operator fix and addition

* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONEt
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
This commit is contained in:
Dario Coscia
2023-09-09 22:09:34 +02:00
committed by Nicola Demo
parent 83ecdb0eab
commit 603f56d264
6 changed files with 315 additions and 33 deletions

157
pina/model/fno.py Normal file
View File

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
class FNO(torch.nn.Module):
"""
The PINA implementation of Fourier Neural Operator network.
Fourier Neural Operator (FNO) is a general architecture for learning Operators.
Unlike traditional machine learning methods FNO is designed to map
entire functions to other functions. It can be trained both with
Supervised learning strategies. FNO does global convolution by performing the
operation on the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). "Fourier neural operator for
parametric partial differential equations".
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
def __init__(self,
lifting_net,
projecting_net,
n_modes,
dimensions = 3,
padding = 8,
padding_type = "constant",
inner_size = 20,
n_layers = 2,
func = nn.Tanh,
layers = None):
super().__init__()
# check type consistency
check_consistency(lifting_net, nn.Module)
check_consistency(projecting_net, nn.Module)
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
if layers is not None:
if isinstance(layers, (tuple, list)):
check_consistency(layers, int)
else:
raise ValueError('layers must be tuple or list of int.')
if not isinstance(n_modes, (list, tuple, int)):
raise ValueError('n_modes must be a int or list or tuple of valid modes.'
' More information on the official documentation.')
# assign variables
# TODO check input lifting net and input projecting net
self._lifting_net = lifting_net
self._projecting_net = projecting_net
self._padding = padding
# initialize fourier layer for each dimension
if dimensions == 1:
fourier_layer = FourierBlock1D
elif dimensions == 2:
fourier_layer = FourierBlock2D
elif dimensions == 3:
fourier_layer = FourierBlock3D
else:
NotImplementedError('FNO implemented only for 1D/2D/3D data.')
# Here we build the FNO by stacking Fourier Blocks
# 1. Assign output dimensions for each FNO layer
if layers is None:
layers = [inner_size] * n_layers
# 2. Assign activation functions for each FNO layer
if isinstance(func, list):
if len(layers) != len(func):
raise RuntimeError('Uncosistent number of layers and functions.')
self._functions = func
else:
self._functions = [func for _ in range(len(layers))]
# 3. Assign modes functions for each FNO layer
if isinstance(n_modes, list):
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(n_modes):
raise RuntimeError('Uncosistent number of layers and functions.')
elif all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers)
else:
n_modes = [n_modes] * len(layers)
# 4. Build the FNO network
tmp_layers = layers.copy()
out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
tmp_layers.insert(0, out_feats)
self._layers = []
for i in range(len(tmp_layers) - 1):
self._layers.append(
fourier_layer(input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i],
activation=self._functions[i])
)
self._layers = nn.Sequential(*self._layers)
# 5. Padding values for spectral conv
if isinstance(padding, int):
padding = [padding] * dimensions
self._ipad = [-pad if pad > 0 else None for pad in padding[:dimensions]]
self._padding_type = padding_type
self._pad = [val for pair in zip([0]*dimensions, padding) for val in pair]
def forward(self, x):
"""
Forward computation for Fourier Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
:param torch.Tensor x: The input tensor for fourier block, depending on
``dimension`` in the initialization. In particular it is expected
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from the FNO.
:rtype: torch.Tensor
"""
# lifting the input in higher dimensional space
x = self._lifting_net(x)
# permuting the input [batch, channels, x, y, ...]
permutation_idx = [0, x.ndim-1, *[i for i in range(1, x.ndim-1)]]
x = x.permute(permutation_idx)
# padding the input
x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)
# apply fourier layers
x = self._layers(x)
# remove padding
idxs = [slice(None), slice(None)] + [slice(pad) for pad in self._ipad]
x = x[idxs]
# permuting back [batch, x, y, ..., channels]
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
x = x.permute(permutation_idx)
# apply projecting operator and return
return self._projecting_net(x)