* Building FNO for 1D/2D/3D data * Fixing bug in trunk/branch net in DeepONet * Fixing type check bug in spectral conv * Adding tests for FNO * Fixing bug in Fourier Layer (conv1d/2d/3d)
158 lines
6.0 KiB
Python
158 lines
6.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from ..utils import check_consistency
|
|
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
|
|
|
|
|
class FNO(torch.nn.Module):
    """
    The PINA implementation of Fourier Neural Operator network.

    Fourier Neural Operator (FNO) is a general architecture for learning
    operators. Unlike traditional machine learning methods, FNO is designed
    to map entire functions to other functions. It can be trained with
    supervised learning strategies. FNO performs a global convolution by
    carrying out the operation in Fourier space.

    .. seealso::

        **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
        Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020).
        "Fourier neural operator for parametric partial differential
        equations". arXiv preprint
        `arXiv:2010.08895 <https://arxiv.org/abs/2010.08895>`_.
    """

    def __init__(self,
                 lifting_net,
                 projecting_net,
                 n_modes,
                 dimensions=3,
                 padding=8,
                 padding_type="constant",
                 inner_size=20,
                 n_layers=2,
                 func=nn.Tanh,
                 layers=None):
        """
        :param torch.nn.Module lifting_net: Network applied first to the
            input. The size of its last output dimension sets the number of
            input channels of the first Fourier block.
        :param torch.nn.Module projecting_net: Network applied last, to the
            output of the Fourier blocks.
        :param int | list | tuple n_modes: Fourier modes passed to the Fourier
            blocks. A list whose items are all lists/tuples gives one entry
            per layer and must match the number of layers. An int, a tuple
            or a list of ints is reused for every layer.
        :param int dimensions: Number of spatial dimensions (1, 2 or 3).
        :param int | list[int] padding: Number of entries appended at the end
            of each spatial dimension before the Fourier blocks; they are
            removed again afterwards.
        :param str padding_type: ``mode`` passed to
            :func:`torch.nn.functional.pad`.
        :param int inner_size: Channels of every Fourier block when
            ``layers`` is not given.
        :param int n_layers: Number of Fourier blocks when ``layers`` is not
            given.
        :param func: Activation class, or a list/tuple of activation classes
            (one per layer), passed to the Fourier blocks as ``activation``.
        :param list[int] | tuple[int] layers: Output channels of each Fourier
            block. Defaults to ``[inner_size] * n_layers``.
        :raises ValueError: If ``layers`` is not a list/tuple, or ``n_modes``
            is not an int, list or tuple.
        :raises NotImplementedError: If ``dimensions`` is not 1, 2 or 3.
        :raises RuntimeError: If per-layer activations or modes do not match
            the number of layers.
        """
        super().__init__()

        # check type consistency
        check_consistency(lifting_net, nn.Module)
        check_consistency(projecting_net, nn.Module)
        check_consistency(dimensions, int)
        # NOTE(review): a padding sequence reaches _padding_specs only if
        # check_consistency accepts iterables of ints — confirm.
        check_consistency(padding, int)
        check_consistency(padding_type, str)
        check_consistency(inner_size, int)
        check_consistency(n_layers, int)
        check_consistency(func, nn.Module, subclass=True)
        if layers is not None:
            if isinstance(layers, (tuple, list)):
                check_consistency(layers, int)
            else:
                raise ValueError('layers must be tuple or list of int.')
        if not isinstance(n_modes, (list, tuple, int)):
            raise ValueError('n_modes must be a int or list or tuple of valid modes.'
                             ' More information on the official documentation.')

        # assign variables
        # TODO check input lifting net and input projecting net
        self._lifting_net = lifting_net
        self._projecting_net = projecting_net
        self._padding = padding

        # Fourier block class matching the number of spatial dimensions
        fourier_layer = self._select_fourier_block(dimensions)

        # 1. Output channels of each Fourier block. Tuples are accepted above,
        #    so work on a list copy (tuples have no ``copy``/``insert``).
        if layers is None:
            layers = [inner_size] * n_layers
        layers = list(layers)

        # 2. One activation class per layer
        if isinstance(func, (list, tuple)):
            if len(layers) != len(func):
                raise RuntimeError('Inconsistent number of layers and functions.')
            self._functions = list(func)
        else:
            self._functions = [func] * len(layers)

        # 3. One modes specification per layer
        n_modes = self._expand_n_modes(n_modes, len(layers))

        # 4. Build the FNO network.
        # The lifting net is probed once only to read its output width, so no
        # autograd graph is needed.
        # NOTE(review): the probe assumes lifting_net takes ``dimensions``
        # input features — confirm (see TODO above).
        with torch.no_grad():
            out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
        channels = [out_feats, *layers]
        blocks = []
        for i, (in_ch, out_ch) in enumerate(zip(channels[:-1], channels[1:])):
            # index (not zip) n_modes so a too-short spec still errors loudly
            blocks.append(fourier_layer(input_numb_fields=in_ch,
                                        output_numb_fields=out_ch,
                                        n_modes=n_modes[i],
                                        activation=self._functions[i]))
        self._layers = nn.Sequential(*blocks)

        # 5. Padding values for spectral conv
        self._padding_type = padding_type
        self._ipad, self._pad = self._padding_specs(padding, dimensions)

    @staticmethod
    def _select_fourier_block(dimensions):
        """
        Return the Fourier block class for ``dimensions`` spatial dimensions.

        :param int dimensions: Number of spatial dimensions.
        :raises NotImplementedError: If ``dimensions`` is not 1, 2 or 3.
        """
        if dimensions == 1:
            return FourierBlock1D
        if dimensions == 2:
            return FourierBlock2D
        if dimensions == 3:
            return FourierBlock3D
        raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')

    @staticmethod
    def _expand_n_modes(n_modes, n_layers):
        """
        Return a list holding one ``n_modes`` specification per layer.

        A list whose items are all lists/tuples is taken as per-layer modes
        and must have exactly ``n_layers`` items. An int, a tuple or a list
        of ints is shared by every layer. Any other list (mixed items) is
        returned unchanged.

        :param int | list | tuple n_modes: Modes specification.
        :param int n_layers: Number of Fourier blocks.
        :raises RuntimeError: If per-layer modes do not match ``n_layers``.
        """
        if not isinstance(n_modes, list):
            return [n_modes] * n_layers
        if all(isinstance(m, (list, tuple)) for m in n_modes):
            if len(n_modes) != n_layers:
                raise RuntimeError('Inconsistent number of layers and modes.')
            return list(n_modes)
        if all(isinstance(m, int) for m in n_modes):
            return [n_modes] * n_layers
        return n_modes

    @staticmethod
    def _padding_specs(padding, dimensions):
        """
        Compute the arguments used to pad and un-pad the spatial dimensions.

        ``padding`` (an int, or one int per spatial dimension in order) is
        the number of entries appended at the end of each spatial dimension.
        Entries beyond ``dimensions`` are ignored and missing ones count as 0.

        :param int | list[int] padding: Padding per spatial dimension.
        :param int dimensions: Number of spatial dimensions.
        :return: ``(ipad, pad)``. ``ipad`` holds the slice stops that remove
            the padding from each spatial dimension in order (``None`` keeps
            the whole axis). ``pad`` is the flat list for
            :func:`torch.nn.functional.pad`.
        :rtype: tuple[list, list]
        """
        if isinstance(padding, int):
            padding = [padding] * dimensions
        padding = list(padding)[:dimensions]
        padding += [0] * (dimensions - len(padding))
        ipad = [-pad if pad > 0 else None for pad in padding]
        # torch.nn.functional.pad lists (left, right) pairs starting from the
        # LAST dimension, so walk the spatial dims in reverse to keep each
        # pad aligned with the slice in ``ipad`` that removes it.
        pad = [val for size in reversed(padding) for val in (0, size)]
        return ipad, pad

    def forward(self, x):
        """
        Forward computation for Fourier Neural Operator. It performs a
        lifting of the input by the ``lifting_net``. Then different layers
        of Fourier Blocks are applied. Finally the output is projected
        to the final dimensionality by the ``projecting_net``.

        :param torch.Tensor x: The input tensor for fourier block, depending on
            ``dimension`` in the initialization. In particular it is expected
            * 1D tensors: ``[batch, X, channels]``
            * 2D tensors: ``[batch, X, Y, channels]``
            * 3D tensors: ``[batch, X, Y, Z, channels]``
        :return: The output tensor obtained from the FNO.
        :rtype: torch.Tensor
        """
        # lifting the input in higher dimensional space
        x = self._lifting_net(x)

        # move channels next to batch: [batch, channels, x, y, ...]
        permutation_idx = [0, x.ndim - 1, *range(1, x.ndim - 1)]
        x = x.permute(permutation_idx)

        # padding the input
        x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)

        # apply fourier layers
        x = self._layers(x)

        # remove padding; a tuple (not a list) of slices is the unambiguous
        # multi-dimensional index
        idxs = (slice(None), slice(None), *(slice(pad) for pad in self._ipad))
        x = x[idxs]

        # permuting back [batch, x, y, ..., channels]
        permutation_idx = [0, *range(2, x.ndim), 1]
        x = x.permute(permutation_idx)

        # apply projecting operator and return
        return self._projecting_net(x)
|