From 83ecdb0eab8acd8b74bb1d3f17890a7347503997 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 7 Sep 2023 18:18:28 +0200 Subject: [PATCH] Fourier Block and minor fixes * Adding fourier block 1d/2d/3d * Adding docs to SpectralConvBlock1D/2D/3D and to FourierBlock1D/2D/3D * Adding tests for fourier block --- pina/model/layers/__init__.py | 6 +- pina/model/layers/fourier.py | 147 +++++++++++++++++++++++++++++- pina/model/layers/spectral.py | 56 ++++++++---- tests/test_layers/test_fourier.py | 43 +++++++++ 4 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 tests/test_layers/test_fourier.py diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 71d781c..9a2fd2a 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -3,9 +3,13 @@ __all__ = [ 'ResidualBlock', 'SpectralConvBlock1D', 'SpectralConvBlock2D', - 'SpectralConvBlock3D' + 'SpectralConvBlock3D', + 'FourierBlock1D', + 'FourierBlock2D', + 'FourierBlock3D', ] from .convolution_2d import ContinuousConvBlock from .residual import ResidualBlock from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D +from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D diff --git a/pina/model/layers/fourier.py b/pina/model/layers/fourier.py index 391dc96..6cfd9ad 100644 --- a/pina/model/layers/fourier.py +++ b/pina/model/layers/fourier.py @@ -2,9 +2,14 @@ import torch import torch.nn as nn from ...utils import check_consistency +from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D -class FourierBlock(nn.Module): - """Fourier block base class. Implementation of a fourier block. + +class FourierBlock1D(nn.Module): + """ + Fourier block implementation for three dimensional + input tensor. The combination of Fourier blocks + make up the Fourier Neural Operator .. seealso:: @@ -16,9 +21,143 @@ class FourierBlock(nn.Module): """ - def __init__(self): + def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh): super().__init__() + """ + PINA implementation of Fourier block one dimension. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. The output is then added to a Linear tranformation of the + input in the physical space. Finally an activation function is + applied to the output. + + The block expects an input of size ``[batch, input_numb_fields, N]`` + and returns an output of size ``[batch, output_numb_fields, N]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param list | tuple n_modes: Number of modes to select for each dimension. + It must be at most equal to the ``floor(N/2)+1``. + :param torch.nn.Module activation: The activation function. + """ + # check type consistency + check_consistency(activation(), nn.Module) + + # assign variables + self._spectral_conv = SpectralConvBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=n_modes) + self._activation = activation() + self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1) + def forward(self, x): - pass \ No newline at end of file + return self._activation(self._spectral_conv(x) + self._linear(x)) + + +class FourierBlock2D(nn.Module): + """ + Fourier block implementation for two dimensional + input tensor. The combination of Fourier blocks + make up the Fourier Neural Operator + + .. seealso:: + + **Original reference**: Li, Zongyi, et al. + "Fourier neural operator for parametric partial + differential equations." arXiv preprint + arXiv:2010.08895 (2020) + `_. + + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh): + """ + PINA implementation of Fourier block two dimensions. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. The output is then added to a Linear tranformation of the + input in the physical space. Finally an activation function is + applied to the output. + + The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]`` + and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param list | tuple n_modes: Number of modes to select for each dimension. + It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``. + :param torch.nn.Module activation: The activation function. + """ + super().__init__() + + # check type consistency + check_consistency(activation(), nn.Module) + + # assign variables + self._spectral_conv = SpectralConvBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=n_modes) + self._activation = activation() + self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1) + + def forward(self, x): + shape_x = x.shape + ln = self._linear(x.view(shape_x[0], shape_x[1], -1)) + ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3]) + return self._activation(self._spectral_conv(x) + ln) + + +class FourierBlock3D(nn.Module): + """ + Fourier block implementation for three dimensional + input tensor. The combination of Fourier blocks + make up the Fourier Neural Operator + + .. seealso:: + + **Original reference**: Li, Zongyi, et al. + "Fourier neural operator for parametric partial + differential equations." arXiv preprint + arXiv:2010.08895 (2020) + `_. + + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh): + """ + PINA implementation of Fourier block three dimensions. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. The output is then added to a Linear tranformation of the + input in the physical space. Finally an activation function is + applied to the output. + + The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]`` + and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param list | tuple n_modes: Number of modes to select for each dimension. + It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1`` + and ``floor(Nz/2)+1``. + :param torch.nn.Module activation: The activation function. + """ + super().__init__() + + # check type consistency + check_consistency(activation(), nn.Module) + + # assign variables + self._spectral_conv = SpectralConvBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=n_modes) + self._activation = activation() + self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1) + + def forward(self, x): + shape_x = x.shape + ln = self._linear(x.view(shape_x[0], shape_x[1], -1)) + ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4]) + return self._activation(self._spectral_conv(x) + ln) diff --git a/pina/model/layers/spectral.py b/pina/model/layers/spectral.py index 465ff18..2832e7a 100644 --- a/pina/model/layers/spectral.py +++ b/pina/model/layers/spectral.py @@ -12,14 +12,18 @@ class SpectralConvBlock1D(nn.Module): def __init__(self, input_numb_fields, output_numb_fields, n_modes): """ - TODO + PINA implementation of spectral convolution. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. - :param input_numb_fields: _description_ - :type input_numb_fields: _type_ - :param output_numb_fields: _description_ - :type output_numb_fields: _type_ - :param n_modes: _description_ - :type n_modes: _type_ + The block expects an input of size ``[batch, input_numb_fields, N]`` + and returns an output of size ``[batch, output_numb_fields, N]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param int n_modes: Number of modes to select, it must be at most equal + to the ``floor(N/2)+1``. """ super().__init__() @@ -69,9 +73,6 @@ class SpectralConvBlock1D(nn.Module): """ batch_size = x.shape[0] - # if x.shape[-1] // 2 + 1 < self._modes: - # raise RuntimeError('Number of modes is too high, decrease number of modes.') - # Compute Fourier transform of the input x_ft = torch.fft.rfft(x) @@ -95,6 +96,20 @@ class SpectralConvBlock2D(nn.Module): """ def __init__(self, input_numb_fields, output_numb_fields, n_modes): + """ + PINA implementation of spectral convolution. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. + + The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]`` + and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param list | tuple n_modes: Number of modes to select for each dimension. + It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``. + """ super().__init__() # check type consistency @@ -188,16 +203,19 @@ class SpectralConvBlock3D(nn.Module): def __init__(self, input_numb_fields, output_numb_fields, n_modes): """ - TODO + PINA implementation of spectral convolution. The module computes + the spectral convolution of the input with a linear kernel in the + fourier space, and then it maps the input back to the physical + space. - :param input_numb_fields: _description_ - :type input_numb_fields: _type_ - :param output_numb_fields: _description_ - :type output_numb_fields: _type_ - :param n_modes: _description_ - :type n_modes: _type_ - :raises ValueError: _description_ - :raises ValueError: _description_ + The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]`` + and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``. + + :param int input_numb_fields: The number of channels for the input. + :param int output_numb_fields: The number of channels for the output. + :param list | tuple n_modes: Number of modes to select for each dimension. + It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1`` + and ``floor(Nz/2)+1``. """ super().__init__() diff --git a/tests/test_layers/test_fourier.py b/tests/test_layers/test_fourier.py new file mode 100644 index 0000000..af9425a --- /dev/null +++ b/tests/test_layers/test_fourier.py @@ -0,0 +1,43 @@ +from pina.model.layers import FourierBlock1D, FourierBlock2D, FourierBlock3D +import torch + +input_numb_fields = 3 +output_numb_fields = 4 +batch = 5 + +def test_constructor_1d(): + FourierBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=5) + +def test_forward_1d(): + sconv = FourierBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=4) + x = torch.rand(batch, input_numb_fields, 10) + sconv(x) + + +def test_constructor_2d(): + FourierBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + +def test_forward_2d(): + sconv = FourierBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10) + sconv(x) + +def test_constructor_3d(): + FourierBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + +def test_forward_3d(): + sconv = FourierBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10, 10) + sconv(x)