Neural Operator fix and addition

* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONEt
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
This commit is contained in:
Dario Coscia
2023-09-09 22:09:34 +02:00
committed by Nicola Demo
parent 83ecdb0eab
commit 603f56d264
6 changed files with 315 additions and 33 deletions

View File

@@ -2,8 +2,10 @@ __all__ = [
'FeedForward',
'MultiFeedForward',
'DeepONet',
'FNO',
]
from .feed_forward import FeedForward
from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet
from .fno import FNO

View File

@@ -108,8 +108,10 @@ class DeepONet(torch.nn.Module):
check_consistency(trunk_net, torch.nn.Module)
# check trunk branch nets consistency
trunk_out_dim = trunk_net.layers[-1].out_features
branch_out_dim = branch_net.layers[-1].out_features
input_trunk = torch.rand(10, len(input_indeces_trunk_net))
input_branch = torch.rand(10, len(input_indeces_branch_net))
trunk_out_dim = trunk_net(input_trunk).shape[-1]
branch_out_dim = branch_net(input_branch).shape[-1]
if trunk_out_dim != branch_out_dim:
raise ValueError('Branch and trunk networks have not the same '
'output dimension.')

157
pina/model/fno.py Normal file
View File

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
class FNO(torch.nn.Module):
"""
The PINA implementation of Fourier Neural Operator network.
Fourier Neural Operator (FNO) is a general architecture for learning Operators.
Unlike traditional machine learning methods FNO is designed to map
entire functions to other functions. It can be trained both with
Supervised learning strategies. FNO does global convolution by performing the
operation on the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). "Fourier neural operator for
parametric partial differential equations".
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
def __init__(self,
lifting_net,
projecting_net,
n_modes,
dimensions = 3,
padding = 8,
padding_type = "constant",
inner_size = 20,
n_layers = 2,
func = nn.Tanh,
layers = None):
super().__init__()
# check type consistency
check_consistency(lifting_net, nn.Module)
check_consistency(projecting_net, nn.Module)
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
if layers is not None:
if isinstance(layers, (tuple, list)):
check_consistency(layers, int)
else:
raise ValueError('layers must be tuple or list of int.')
if not isinstance(n_modes, (list, tuple, int)):
raise ValueError('n_modes must be a int or list or tuple of valid modes.'
' More information on the official documentation.')
# assign variables
# TODO check input lifting net and input projecting net
self._lifting_net = lifting_net
self._projecting_net = projecting_net
self._padding = padding
# initialize fourier layer for each dimension
if dimensions == 1:
fourier_layer = FourierBlock1D
elif dimensions == 2:
fourier_layer = FourierBlock2D
elif dimensions == 3:
fourier_layer = FourierBlock3D
else:
NotImplementedError('FNO implemented only for 1D/2D/3D data.')
# Here we build the FNO by stacking Fourier Blocks
# 1. Assign output dimensions for each FNO layer
if layers is None:
layers = [inner_size] * n_layers
# 2. Assign activation functions for each FNO layer
if isinstance(func, list):
if len(layers) != len(func):
raise RuntimeError('Uncosistent number of layers and functions.')
self._functions = func
else:
self._functions = [func for _ in range(len(layers))]
# 3. Assign modes functions for each FNO layer
if isinstance(n_modes, list):
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(n_modes):
raise RuntimeError('Uncosistent number of layers and functions.')
elif all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers)
else:
n_modes = [n_modes] * len(layers)
# 4. Build the FNO network
tmp_layers = layers.copy()
out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
tmp_layers.insert(0, out_feats)
self._layers = []
for i in range(len(tmp_layers) - 1):
self._layers.append(
fourier_layer(input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i],
activation=self._functions[i])
)
self._layers = nn.Sequential(*self._layers)
# 5. Padding values for spectral conv
if isinstance(padding, int):
padding = [padding] * dimensions
self._ipad = [-pad if pad > 0 else None for pad in padding[:dimensions]]
self._padding_type = padding_type
self._pad = [val for pair in zip([0]*dimensions, padding) for val in pair]
def forward(self, x):
"""
Forward computation for Fourier Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
:param torch.Tensor x: The input tensor for fourier block, depending on
``dimension`` in the initialization. In particular it is expected
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from the FNO.
:rtype: torch.Tensor
"""
# lifting the input in higher dimensional space
x = self._lifting_net(x)
# permuting the input [batch, channels, x, y, ...]
permutation_idx = [0, x.ndim-1, *[i for i in range(1, x.ndim-1)]]
x = x.permute(permutation_idx)
# padding the input
x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)
# apply fourier layers
x = self._layers(x)
# remove padding
idxs = [slice(None), slice(None)] + [slice(pad) for pad in self._ipad]
x = x[idxs]
# permuting back [batch, x, y, ..., channels]
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
x = x.permute(permutation_idx)
# apply projecting operator and return
return self._projecting_net(x)

View File

@@ -20,7 +20,6 @@ class FourierBlock1D(nn.Module):
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
super().__init__()
"""
@@ -50,9 +49,19 @@ class FourierBlock1D(nn.Module):
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
"""
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
fourier block of size ``[batch, output_numb_fields, x]``.
:rtype: torch.Tensor
"""
return self._activation(self._spectral_conv(x) + self._linear(x))
@@ -71,7 +80,6 @@ class FourierBlock2D(nn.Module):
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block two dimensions. The module computes
@@ -100,13 +108,22 @@ class FourierBlock2D(nn.Module):
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
self._linear = nn.Conv2d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3])
return self._activation(self._spectral_conv(x) + ln)
"""
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x, y]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
fourier block of size ``[batch, output_numb_fields, x, y, z]``.
:rtype: torch.Tensor
"""
return self._activation(self._spectral_conv(x) + self._linear(x))
class FourierBlock3D(nn.Module):
@@ -124,7 +141,6 @@ class FourierBlock3D(nn.Module):
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block three dimensions. The module computes
@@ -154,10 +170,19 @@ class FourierBlock3D(nn.Module):
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
self._linear = nn.Conv3d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4])
return self._activation(self._spectral_conv(x) + ln)
"""
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x, y, z]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
fourier block of size ``[batch, output_numb_fields, x, y, z]``.
:rtype: torch.Tensor
"""
return self._activation(self._spectral_conv(x) + self._linear(x))

View File

@@ -115,15 +115,19 @@ class SpectralConvBlock2D(nn.Module):
# check type consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
if not isinstance(n_modes, (tuple, list)):
raise ValueError('expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
if len(n_modes) != 2:
raise ValueError('expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
check_consistency(n_modes, int)
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 2:
raise ValueError('Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
elif isinstance(n_modes, int):
n_modes = [n_modes]*2
else:
raise ValueError('Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
# assign variables
@@ -222,15 +226,19 @@ class SpectralConvBlock3D(nn.Module):
# check type consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
if not isinstance(n_modes, (tuple, list)):
raise ValueError('expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
if len(n_modes) != 3:
raise ValueError('expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
check_consistency(n_modes, int)
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 3:
raise ValueError('Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
elif isinstance(n_modes, int):
n_modes = [n_modes]*3
else:
raise ValueError('Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
# assign variables
self._modes = n_modes