Neural Operator fixes and additions
* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONet
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
This commit is contained in:
committed by Nicola Demo
parent 83ecdb0eab
commit 603f56d264
157
pina/model/fno.py
Normal file
@@ -0,0 +1,157 @@
import torch
import torch.nn as nn

from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D


class FNO(torch.nn.Module):
    """
    The PINA implementation of the Fourier Neural Operator network.

    The Fourier Neural Operator (FNO) is a general architecture for learning
    operators. Unlike traditional machine learning methods, FNO is designed to
    map entire functions to other functions, and it can be trained with
    supervised learning strategies. FNO performs a global convolution by
    carrying out the operation in the Fourier space.

    .. seealso::

        **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
        Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020).
        "Fourier neural operator for parametric partial differential
        equations".
        DOI: `arXiv preprint arXiv:2010.08895.
        <https://arxiv.org/abs/2010.08895>`_
    """
    def __init__(self,
                 lifting_net,
                 projecting_net,
                 n_modes,
                 dimensions=3,
                 padding=8,
                 padding_type="constant",
                 inner_size=20,
                 n_layers=2,
                 func=nn.Tanh,
                 layers=None):
        super().__init__()

        # check type consistency
        check_consistency(lifting_net, nn.Module)
        check_consistency(projecting_net, nn.Module)
        check_consistency(dimensions, int)
        check_consistency(padding, int)
        check_consistency(padding_type, str)
        check_consistency(inner_size, int)
        check_consistency(n_layers, int)
        check_consistency(func, nn.Module, subclass=True)
        if layers is not None:
            if isinstance(layers, (tuple, list)):
                check_consistency(layers, int)
            else:
                raise ValueError('layers must be a tuple or list of int.')
        if not isinstance(n_modes, (list, tuple, int)):
            raise ValueError('n_modes must be an int or a list or tuple of valid modes.'
                             ' More information in the official documentation.')

        # assign variables
        # TODO check input lifting net and input projecting net
        self._lifting_net = lifting_net
        self._projecting_net = projecting_net
        self._padding = padding

        # initialize the fourier layer for each dimension
        if dimensions == 1:
            fourier_layer = FourierBlock1D
        elif dimensions == 2:
            fourier_layer = FourierBlock2D
        elif dimensions == 3:
            fourier_layer = FourierBlock3D
        else:
            raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')

        # Here we build the FNO by stacking Fourier Blocks

        # 1. Assign output dimensions for each FNO layer
        if layers is None:
            layers = [inner_size] * n_layers

        # 2. Assign activation functions for each FNO layer
        if isinstance(func, list):
            if len(layers) != len(func):
                raise RuntimeError('Inconsistent number of layers and functions.')
            self._functions = func
        else:
            self._functions = [func for _ in range(len(layers))]

        # 3. Assign Fourier modes for each FNO layer
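        # (an int or a list of ints is replicated across all layers, while a
        # list of lists is interpreted as per-layer modes and must match the
        # number of layers)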
        if isinstance(n_modes, list):
            if all(isinstance(i, list) for i in n_modes) and len(layers) != len(n_modes):
                raise RuntimeError('Inconsistent number of layers and modes.')
            elif all(isinstance(i, int) for i in n_modes):
                n_modes = [n_modes] * len(layers)
        else:
            n_modes = [n_modes] * len(layers)

        # 4. Build the FNO network
        tmp_layers = layers.copy()
        out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
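        # (the number of lifted fields is inferred by probing the lifting net
        # with a dummy batch of ``dimensions`` input features)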
        tmp_layers.insert(0, out_feats)

        self._layers = []
        for i in range(len(tmp_layers) - 1):
            self._layers.append(
                fourier_layer(input_numb_fields=tmp_layers[i],
                              output_numb_fields=tmp_layers[i + 1],
                              n_modes=n_modes[i],
                              activation=self._functions[i])
            )
        self._layers = nn.Sequential(*self._layers)

        # 5. Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding] * dimensions
        self._ipad = [-pad if pad > 0 else None for pad in padding[:dimensions]]
        self._padding_type = padding_type
        self._pad = [val for pair in zip([0] * dimensions, padding) for val in pair]
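        # (torch.nn.functional.pad consumes (left, right) pairs starting from
        # the last dimension, so the interleaved zeros above apply right-only
        # padding to each spatial axis)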

    def forward(self, x):
        """
        Forward computation for the Fourier Neural Operator. It performs a
        lifting of the input by the ``lifting_net``. Then several Fourier
        Blocks are applied. Finally, the output is projected to the final
        dimensionality by the ``projecting_net``.

        :param torch.Tensor x: The input tensor for the Fourier blocks,
            depending on ``dimensions`` in the initialization. In particular
            it is expected to be:

            * 1D tensors: ``[batch, X, channels]``
            * 2D tensors: ``[batch, X, Y, channels]``
            * 3D tensors: ``[batch, X, Y, Z, channels]``

        :return: The output tensor obtained from the FNO.
        :rtype: torch.Tensor
        """

        # lifting the input into a higher dimensional space
        x = self._lifting_net(x)

        # permuting the input to [batch, channels, x, y, ...]
        permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
        x = x.permute(permutation_idx)

        # padding the input
        x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)

        # apply fourier layers
        x = self._layers(x)

        # remove padding
        idxs = (slice(None), slice(None)) + tuple(slice(pad) for pad in self._ipad)
        x = x[idxs]
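        # (slice(-pad) keeps everything except the last ``pad`` entries along
        # each padded spatial axis, undoing the right-side padding applied above)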

        # permuting back to [batch, x, y, ..., channels]
        permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
        x = x.permute(permutation_idx)

        # apply the projecting operator and return
        return self._projecting_net(x)
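Below is a minimal usage sketch, not part of the commit, for the class added above. It assumes the module is importable as ``pina.model.fno`` and that the ``FourierBlock2D`` layers preserve the spatial resolution, as standard FNO blocks do. Since the constructor probes the lifting net with ``dimensions`` input features, the example uses a 2-channel input for 2D data; the ``nn.Linear`` lifting/projecting nets and all sizes are illustrative choices, not fixed by the module.

import torch
import torch.nn as nn

from pina.model.fno import FNO

# illustrative lifting/projecting nets: any nn.Module acting on the channel
# dimension works; the lifting net input size must equal `dimensions`
lifting_net = nn.Linear(2, 24)      # 2 input channels -> 24 lifted fields
projecting_net = nn.Linear(24, 1)   # 24 hidden fields -> 1 output channel

fno = FNO(lifting_net=lifting_net,
          projecting_net=projecting_net,
          n_modes=8,                # Fourier modes kept in each layer
          dimensions=2,             # 2D data: [batch, X, Y, channels]
          inner_size=24,            # output fields of each Fourier block
          n_layers=2)

x = torch.rand(4, 32, 32, 2)        # [batch, X, Y, channels]
y = fno(x)                          # expected shape: [4, 32, 32, 1]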