From 603f56d264fc7a5ae33320730d5fb11920646145 Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Sat, 9 Sep 2023 22:09:34 +0200
Subject: [PATCH] Neural Operator fix and addition

* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONet
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
---
 pina/model/__init__.py        |   2 +
 pina/model/deeponet.py        |   6 +-
 pina/model/fno.py             | 157 ++++++++++++++++++++++++++++++++++
 pina/model/layers/fourier.py  |  55 ++++++++----
 pina/model/layers/spectral.py |  40 +++++----
 tests/test_model/test_fno.py  |  88 +++++++++++++++++++
 6 files changed, 315 insertions(+), 33 deletions(-)
 create mode 100644 pina/model/fno.py
 create mode 100644 tests/test_model/test_fno.py

diff --git a/pina/model/__init__.py b/pina/model/__init__.py
index 81fbc09..a1ea5b2 100644
--- a/pina/model/__init__.py
+++ b/pina/model/__init__.py
@@ -2,8 +2,10 @@ __all__ = [
     'FeedForward',
     'MultiFeedForward',
     'DeepONet',
+    'FNO',
 ]
 
 from .feed_forward import FeedForward
 from .multi_feed_forward import MultiFeedForward
 from .deeponet import DeepONet
+from .fno import FNO
diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py
index 97f29f8..26e5bdf 100644
--- a/pina/model/deeponet.py
+++ b/pina/model/deeponet.py
@@ -108,8 +108,10 @@ class DeepONet(torch.nn.Module):
         check_consistency(trunk_net, torch.nn.Module)
 
         # check trunk branch nets consistency
-        trunk_out_dim = trunk_net.layers[-1].out_features
-        branch_out_dim = branch_net.layers[-1].out_features
+        input_trunk = torch.rand(10, len(input_indeces_trunk_net))
+        input_branch = torch.rand(10, len(input_indeces_branch_net))
+        trunk_out_dim = trunk_net(input_trunk).shape[-1]
+        branch_out_dim = branch_net(input_branch).shape[-1]
         if trunk_out_dim != branch_out_dim:
             raise ValueError('Branch and trunk networks have not the same '
                              'output dimension.')
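The new consistency check works for any ``torch.nn.Module``, not only for networks exposing a ``layers`` attribute with ``out_features``. A minimal sketch of the probing idea, using placeholder nets rather than PINA models:

import torch

# placeholder trunk/branch nets (plain torch modules, not PINA models), probed
# like the updated check: a dummy batch is passed through and the trailing
# output dimensions are compared
trunk_net = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.Tanh(),
                                torch.nn.Linear(32, 10))
branch_net = torch.nn.Sequential(torch.nn.Linear(3, 10))

trunk_out_dim = trunk_net(torch.rand(10, 2)).shape[-1]
branch_out_dim = branch_net(torch.rand(10, 3)).shape[-1]
assert trunk_out_dim == branch_out_dim   # both 10, so the check passes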
diff --git a/pina/model/fno.py b/pina/model/fno.py
new file mode 100644
index 0000000..d90c380
--- /dev/null
+++ b/pina/model/fno.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+from ..utils import check_consistency
+from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
+
+
+class FNO(torch.nn.Module):
+    """
+    The PINA implementation of the Fourier Neural Operator network.
+
+    The Fourier Neural Operator (FNO) is a general architecture for learning
+    operators. Unlike traditional machine learning methods, FNO is designed to
+    map entire functions to other functions. It can be trained with supervised
+    learning strategies. FNO performs a global convolution by carrying out the
+    operation in the Fourier space.
+
+    .. seealso::
+
+        **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
+        Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020).
+        "Fourier neural operator for parametric partial differential
+        equations".
+        DOI: `arXiv preprint arXiv:2010.08895.
+        <https://arxiv.org/abs/2010.08895>`_
+    """
+    def __init__(self,
+                 lifting_net,
+                 projecting_net,
+                 n_modes,
+                 dimensions=3,
+                 padding=8,
+                 padding_type="constant",
+                 inner_size=20,
+                 n_layers=2,
+                 func=nn.Tanh,
+                 layers=None):
+        super().__init__()
+
+        # check type consistency
+        check_consistency(lifting_net, nn.Module)
+        check_consistency(projecting_net, nn.Module)
+        check_consistency(dimensions, int)
+        check_consistency(padding, int)
+        check_consistency(padding_type, str)
+        check_consistency(inner_size, int)
+        check_consistency(n_layers, int)
+        check_consistency(func, nn.Module, subclass=True)
+        if layers is not None:
+            if isinstance(layers, (tuple, list)):
+                check_consistency(layers, int)
+            else:
+                raise ValueError('layers must be a tuple or list of int.')
+        if not isinstance(n_modes, (list, tuple, int)):
+            raise ValueError('n_modes must be an int or a list/tuple of valid '
+                             'modes. More information in the official '
+                             'documentation.')
+
+        # assign variables
+        # TODO check input lifting net and input projecting net
+        self._lifting_net = lifting_net
+        self._projecting_net = projecting_net
+        self._padding = padding
+
+        # initialize fourier layer for each dimension
+        if dimensions == 1:
+            fourier_layer = FourierBlock1D
+        elif dimensions == 2:
+            fourier_layer = FourierBlock2D
+        elif dimensions == 3:
+            fourier_layer = FourierBlock3D
+        else:
+            raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')
+
+        # Here we build the FNO by stacking Fourier Blocks
+
+        # 1. Assign output dimensions for each FNO layer
+        if layers is None:
+            layers = [inner_size] * n_layers
+
+        # 2. Assign activation functions for each FNO layer
+        if isinstance(func, list):
+            if len(layers) != len(func):
+                raise RuntimeError('Inconsistent number of layers and functions.')
+            self._functions = func
+        else:
+            self._functions = [func for _ in range(len(layers))]
+
+        # 3. Assign modes for each FNO layer
+        if isinstance(n_modes, list):
+            if all(isinstance(i, list) for i in n_modes) and len(layers) != len(n_modes):
+                raise RuntimeError('Inconsistent number of layers and modes.')
+            elif all(isinstance(i, int) for i in n_modes):
+                n_modes = [n_modes] * len(layers)
+        else:
+            n_modes = [n_modes] * len(layers)
+
+        # 4. Build the FNO network
+        tmp_layers = list(layers)
+        out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
+        tmp_layers.insert(0, out_feats)
+
+        self._layers = []
+        for i in range(len(tmp_layers) - 1):
+            self._layers.append(
+                fourier_layer(input_numb_fields=tmp_layers[i],
+                              output_numb_fields=tmp_layers[i + 1],
+                              n_modes=n_modes[i],
+                              activation=self._functions[i])
+            )
+        self._layers = nn.Sequential(*self._layers)
+
+        # 5. Padding values for spectral conv
+        if isinstance(padding, int):
+            padding = [padding] * dimensions
+        self._ipad = [-pad if pad > 0 else None for pad in padding[:dimensions]]
+        self._padding_type = padding_type
+        self._pad = [val for pair in zip([0] * dimensions, padding)
+                     for val in pair]
+
+    def forward(self, x):
+        """
+        Forward computation for the Fourier Neural Operator. It performs a
+        lifting of the input by the ``lifting_net``. Then several Fourier
+        Blocks are applied. Finally, the output is projected to the final
+        dimensionality by the ``projecting_net``.
+
+        :param torch.Tensor x: The input tensor for the FNO, whose expected
+            shape depends on ``dimensions`` passed in the initialization:
+
+            * 1D tensors: ``[batch, X, channels]``
+            * 2D tensors: ``[batch, X, Y, channels]``
+            * 3D tensors: ``[batch, X, Y, Z, channels]``
+
+        :return: The output tensor obtained from the FNO.
+        :rtype: torch.Tensor
+        """
+        # lifting the input into a higher dimensional space
+        x = self._lifting_net(x)
+
+        # permuting the input to [batch, channels, x, y, ...]
+        permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
+        x = x.permute(permutation_idx)
+
+        # padding the input
+        x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)
+
+        # apply fourier layers
+        x = self._layers(x)
+
+        # remove padding
+        idxs = [slice(None), slice(None)] + [slice(pad) for pad in self._ipad]
+        x = x[tuple(idxs)]
+
+        # permuting back to [batch, x, y, ..., channels]
+        permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
+        x = x.permute(permutation_idx)
+
+        # apply the projecting operator and return
+        return self._projecting_net(x)
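For reference, a short usage sketch of the new ``FNO`` class, mirroring the shapes exercised in ``test_2d_forward`` below; the widths, mode count, and batch size are illustrative choices only. The number of input channels matches ``dimensions`` here because the constructor currently probes the lifting net with a ``[10, dimensions]`` tensor (see the TODO above):

import torch
from pina.model import FNO

# lifting: 2 input channels -> 24 features; projection: 24 -> 5 output channels
lifting_net = torch.nn.Linear(2, 24)
projecting_net = torch.nn.Linear(24, 5)

fno = FNO(lifting_net=lifting_net,
          projecting_net=projecting_net,
          n_modes=8,        # Fourier modes kept along each dimension
          dimensions=2,     # 2D data
          inner_size=24,    # width of each Fourier block
          n_layers=2)

x = torch.rand(4, 32, 32, 2)   # [batch, X, Y, channels]
y = fno(x)                     # spatial resolution is preserved
print(y.shape)                 # torch.Size([4, 32, 32, 5])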
diff --git a/pina/model/layers/fourier.py b/pina/model/layers/fourier.py
index 6cfd9ad..4333276 100644
--- a/pina/model/layers/fourier.py
+++ b/pina/model/layers/fourier.py
@@ -20,7 +20,6 @@ class FourierBlock1D(nn.Module):
         `_.
     """
-
     def __init__(self, input_numb_fields, output_numb_fields, n_modes,
                  activation=torch.nn.Tanh):
         super().__init__()
         """
@@ -50,9 +49,19 @@ class FourierBlock1D(nn.Module):
         self._activation = activation()
         self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
-
-    def forward(self, x):
+
+    def forward(self, x):
+        """
+        Forward computation for the Fourier block. It performs a spectral
+        convolution and a linear transformation of the input and sums the
+        results.
+
+        :param x: The input tensor for the Fourier block, expected of size
+            ``[batch, input_numb_fields, x]``.
+        :type x: torch.Tensor
+        :return: The output tensor obtained from the Fourier block,
+            of size ``[batch, output_numb_fields, x]``.
+        :rtype: torch.Tensor
+        """
         return self._activation(self._spectral_conv(x) + self._linear(x))
@@ -71,7 +80,6 @@ class FourierBlock2D(nn.Module):
         `_.
     """
-
     def __init__(self, input_numb_fields, output_numb_fields, n_modes,
                  activation=torch.nn.Tanh):
         """
         PINA implementation of Fourier block two dimensions. The module computes
@@ -100,13 +108,22 @@ class FourierBlock2D(nn.Module):
                                                   output_numb_fields=output_numb_fields,
                                                   n_modes=n_modes)
         self._activation = activation()
-        self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
+        self._linear = nn.Conv2d(input_numb_fields, output_numb_fields, 1)
 
     def forward(self, x):
-        shape_x = x.shape
-        ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
-        ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3])
-        return self._activation(self._spectral_conv(x) + ln)
+        """
+        Forward computation for the Fourier block. It performs a spectral
+        convolution and a linear transformation of the input and sums the
+        results.
+
+        :param x: The input tensor for the Fourier block, expected of size
+            ``[batch, input_numb_fields, x, y]``.
+        :type x: torch.Tensor
+        :return: The output tensor obtained from the Fourier block,
+            of size ``[batch, output_numb_fields, x, y]``.
+        :rtype: torch.Tensor
+        """
+        return self._activation(self._spectral_conv(x) + self._linear(x))
 
 
 class FourierBlock3D(nn.Module):
@@ -124,7 +141,6 @@ class FourierBlock3D(nn.Module):
         `_.
     """
-
     def __init__(self, input_numb_fields, output_numb_fields, n_modes,
                  activation=torch.nn.Tanh):
         """
         PINA implementation of Fourier block three dimensions. The module computes
@@ -154,10 +170,19 @@ class FourierBlock3D(nn.Module):
                                                   output_numb_fields=output_numb_fields,
                                                   n_modes=n_modes)
         self._activation = activation()
-        self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
+        self._linear = nn.Conv3d(input_numb_fields, output_numb_fields, 1)
 
     def forward(self, x):
-        shape_x = x.shape
-        ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
-        ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4])
-        return self._activation(self._spectral_conv(x) + ln)
+        """
+        Forward computation for the Fourier block. It performs a spectral
+        convolution and a linear transformation of the input and sums the
+        results.
+
+        :param x: The input tensor for the Fourier block, expected of size
+            ``[batch, input_numb_fields, x, y, z]``.
+        :type x: torch.Tensor
+        :return: The output tensor obtained from the Fourier block,
+            of size ``[batch, output_numb_fields, x, y, z]``.
+        :rtype: torch.Tensor
+        """
+        return self._activation(self._spectral_conv(x) + self._linear(x))
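With ``Conv2d``/``Conv3d`` as the pointwise skip, each block sums its spectral and linear branches directly, without the reshape workaround removed above. A small shape check for the 2D block; the channel counts, mode values, and resolution are arbitrary illustration values:

import torch
from pina.model.layers.fourier import FourierBlock2D

# spectral convolution + 1x1 Conv2d skip connection, followed by the activation
block = FourierBlock2D(input_numb_fields=3, output_numb_fields=7,
                       n_modes=[4, 4], activation=torch.nn.ReLU)

x = torch.rand(2, 3, 16, 16)   # [batch, input_numb_fields, x, y]
print(block(x).shape)          # torch.Size([2, 7, 16, 16])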
diff --git a/pina/model/layers/spectral.py b/pina/model/layers/spectral.py
index 2832e7a..e6a06f7 100644
--- a/pina/model/layers/spectral.py
+++ b/pina/model/layers/spectral.py
@@ -115,15 +115,19 @@ class SpectralConvBlock2D(nn.Module):
         # check type consistency
         check_consistency(input_numb_fields, int)
         check_consistency(output_numb_fields, int)
-        if not isinstance(n_modes, (tuple, list)):
-            raise ValueError('expected n_modes to be a list or tuple of len two, '
-                             'with each entry corresponding to the number of modes '
-                             'for each dimension ')
-        if len(n_modes) != 2:
-            raise ValueError('expected n_modes to be a list or tuple of len two, '
-                             'with each entry corresponding to the number of modes '
-                             'for each dimension ')
         check_consistency(n_modes, int)
+        if isinstance(n_modes, (tuple, list)):
+            if len(n_modes) != 2:
+                raise ValueError('Expected n_modes to be a list or tuple of len two, '
+                                 'with each entry corresponding to the number of modes '
+                                 'for each dimension.')
+        elif isinstance(n_modes, int):
+            n_modes = [n_modes] * 2
+        else:
+            raise ValueError('Expected n_modes to be a list or tuple of len two, '
+                             'with each entry corresponding to the number of modes '
+                             'for each dimension; or an int value representing the '
+                             'number of modes for all dimensions.')
 
         # assign variables
@@ -222,15 +226,19 @@ class SpectralConvBlock3D(nn.Module):
         # check type consistency
         check_consistency(input_numb_fields, int)
         check_consistency(output_numb_fields, int)
-        if not isinstance(n_modes, (tuple, list)):
-            raise ValueError('expected n_modes to be a list or tuple of len three, '
-                             'with each entry corresponding to the number of modes '
-                             'for each dimension ')
-        if len(n_modes) != 3:
-            raise ValueError('expected n_modes to be a list or tuple of len three, '
-                             'with each entry corresponding to the number of modes '
-                             'for each dimension ')
         check_consistency(n_modes, int)
+        if isinstance(n_modes, (tuple, list)):
+            if len(n_modes) != 3:
+                raise ValueError('Expected n_modes to be a list or tuple of len three, '
+                                 'with each entry corresponding to the number of modes '
+                                 'for each dimension.')
+        elif isinstance(n_modes, int):
+            n_modes = [n_modes] * 3
+        else:
+            raise ValueError('Expected n_modes to be a list or tuple of len three, '
+                             'with each entry corresponding to the number of modes '
+                             'for each dimension; or an int value representing the '
+                             'number of modes for all dimensions.')
 
         # assign variables
         self._modes = n_modes
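The relaxed check lets the spectral blocks accept either a single int, broadcast to every dimension, or an explicit per-dimension list. A short sketch of the two equivalent forms; channel counts, mode values, and resolution are arbitrary:

import torch
from pina.model.layers.spectral import SpectralConvBlock2D

x = torch.rand(2, 3, 16, 16)   # [batch, fields, x, y]

# an int is broadcast to one mode count per dimension ...
conv_int = SpectralConvBlock2D(input_numb_fields=3, output_numb_fields=3, n_modes=5)
# ... which matches passing the per-dimension list explicitly
conv_list = SpectralConvBlock2D(input_numb_fields=3, output_numb_fields=3, n_modes=[5, 5])

assert conv_int(x).shape == conv_list(x).shape == torch.Size([2, 3, 16, 16])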
diff --git a/tests/test_model/test_fno.py b/tests/test_model/test_fno.py
new file mode 100644
index 0000000..7c2613c
--- /dev/null
+++ b/tests/test_model/test_fno.py
@@ -0,0 +1,88 @@
+import torch
+from pina.model import FNO
+
+
+output_channels = 5
+batch_size = 15
+resolution = [30, 40, 50]
+lifting_dim = 128
+
+
+def test_constructor():
+    input_channels = 3
+    lifting_net = torch.nn.Linear(input_channels, lifting_dim)
+    projecting_net = torch.nn.Linear(60, output_channels)
+
+    # simple constructor
+    FNO(lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        n_modes=5,
+        dimensions=3,
+        inner_size=60,
+        n_layers=5)
+
+    # constructor with n_modes as a list
+    FNO(lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        n_modes=[5, 3, 2],
+        dimensions=3,
+        inner_size=60,
+        n_layers=5)
+
+    # constructor with n_modes as a list of lists (one per layer)
+    FNO(lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        n_modes=[[5, 3, 2], [5, 3, 2]],
+        dimensions=3,
+        inner_size=60,
+        n_layers=2)
+
+    # constructor with explicit layer sizes
+    projecting_net = torch.nn.Linear(50, output_channels)
+    FNO(lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        n_modes=5,
+        dimensions=3,
+        layers=[50, 50])
+
+
+def test_1d_forward():
+    input_channels = 1
+    input_ = torch.rand(batch_size, resolution[0], input_channels)
+    lifting_net = torch.nn.Linear(input_channels, lifting_dim)
+    projecting_net = torch.nn.Linear(60, output_channels)
+    fno = FNO(lifting_net=lifting_net,
+              projecting_net=projecting_net,
+              n_modes=5,
+              dimensions=1,
+              inner_size=60,
+              n_layers=2)
+    out = fno(input_)
+    assert out.shape == torch.Size([batch_size, resolution[0], output_channels])
+
+
+def test_2d_forward():
+    input_channels = 2
+    input_ = torch.rand(batch_size, resolution[0], resolution[1], input_channels)
+    lifting_net = torch.nn.Linear(input_channels, lifting_dim)
+    projecting_net = torch.nn.Linear(60, output_channels)
+    fno = FNO(lifting_net=lifting_net,
+              projecting_net=projecting_net,
+              n_modes=5,
+              dimensions=2,
+              inner_size=60,
+              n_layers=2)
+    out = fno(input_)
+    assert out.shape == torch.Size([batch_size, resolution[0], resolution[1],
+                                    output_channels])
+
+
+def test_3d_forward():
+    input_channels = 3
+    input_ = torch.rand(batch_size, resolution[0], resolution[1], resolution[2],
+                        input_channels)
+    lifting_net = torch.nn.Linear(input_channels, lifting_dim)
+    projecting_net = torch.nn.Linear(60, output_channels)
+    fno = FNO(lifting_net=lifting_net,
+              projecting_net=projecting_net,
+              n_modes=5,
+              dimensions=3,
+              inner_size=60,
+              n_layers=2)
+    out = fno(input_)
+    assert out.shape == torch.Size([batch_size, resolution[0], resolution[1],
+                                    resolution[2], output_channels])
\ No newline at end of file
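The tests above use the default ``padding``; the argument only affects the internal computation, since the padding is stripped again before projection, so the spatial shape of the output always matches the input. A hedged sketch with a non-default configuration (the replicate mode and all sizes are illustrative, not taken from the tests):

import torch
from pina.model import FNO

lifting_net = torch.nn.Linear(1, 16)
projecting_net = torch.nn.Linear(16, 1)
fno = FNO(lifting_net=lifting_net,
          projecting_net=projecting_net,
          n_modes=4,
          dimensions=1,
          padding=4,                  # each spatial dimension is padded by 4
          padding_type="replicate",   # any mode accepted by torch.nn.functional.pad
          inner_size=16,
          n_layers=2)

x = torch.rand(8, 50, 1)              # [batch, X, channels]
assert fno(x).shape == torch.Size([8, 50, 1])   # spatial resolution preserved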