Fourier Block and minor fixes

* Adding fourier block 1d/2d/3d
* Adding docs to SpectralConvBlock1D/2D/3D and to FourierBlock1D/2D/3D
* Adding tests for fourier block
This commit is contained in:
Dario Coscia
2023-09-07 18:18:28 +02:00
committed by Nicola Demo
parent 2bf42d5fea
commit 83ecdb0eab
4 changed files with 228 additions and 24 deletions

View File

@@ -3,9 +3,13 @@ __all__ = [
'ResidualBlock', 'ResidualBlock',
'SpectralConvBlock1D', 'SpectralConvBlock1D',
'SpectralConvBlock2D', 'SpectralConvBlock2D',
'SpectralConvBlock3D' 'SpectralConvBlock3D',
'FourierBlock1D',
'FourierBlock2D',
'FourierBlock3D',
] ]
from .convolution_2d import ContinuousConvBlock from .convolution_2d import ContinuousConvBlock
from .residual import ResidualBlock from .residual import ResidualBlock
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D

View File

@@ -2,9 +2,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from ...utils import check_consistency from ...utils import check_consistency
from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
class FourierBlock(nn.Module):
"""Fourier block base class. Implementation of a fourier block. class FourierBlock1D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso:: .. seealso::
@@ -16,9 +21,143 @@ class FourierBlock(nn.Module):
""" """
def __init__(self): def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
super().__init__() super().__init__()
"""
PINA implementation of Fourier block one dimension. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, N]``
and returns an output of size ``[batch, output_numb_fields, N]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(N/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x): def forward(self, x):
pass return self._activation(self._spectral_conv(x) + self._linear(x))
class FourierBlock2D(nn.Module):
"""
Fourier block implementation for two dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso::
**Original reference**: Li, Zongyi, et al.
"Fourier neural operator for parametric partial
differential equations." arXiv preprint
arXiv:2010.08895 (2020)
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block two dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3])
return self._activation(self._spectral_conv(x) + ln)
class FourierBlock3D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso::
**Original reference**: Li, Zongyi, et al.
"Fourier neural operator for parametric partial
differential equations." arXiv preprint
arXiv:2010.08895 (2020)
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block three dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
and ``floor(Nz/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4])
return self._activation(self._spectral_conv(x) + ln)

View File

@@ -12,14 +12,18 @@ class SpectralConvBlock1D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes): def __init__(self, input_numb_fields, output_numb_fields, n_modes):
""" """
TODO PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
:param input_numb_fields: _description_ The block expects an input of size ``[batch, input_numb_fields, N]``
:type input_numb_fields: _type_ and returns an output of size ``[batch, output_numb_fields, N]``.
:param output_numb_fields: _description_
:type output_numb_fields: _type_ :param int input_numb_fields: The number of channels for the input.
:param n_modes: _description_ :param int output_numb_fields: The number of channels for the output.
:type n_modes: _type_ :param int n_modes: Number of modes to select, it must be at most equal
to the ``floor(N/2)+1``.
""" """
super().__init__() super().__init__()
@@ -69,9 +73,6 @@ class SpectralConvBlock1D(nn.Module):
""" """
batch_size = x.shape[0] batch_size = x.shape[0]
# if x.shape[-1] // 2 + 1 < self._modes:
# raise RuntimeError('Number of modes is too high, decrease number of modes.')
# Compute Fourier transform of the input # Compute Fourier transform of the input
x_ft = torch.fft.rfft(x) x_ft = torch.fft.rfft(x)
@@ -95,6 +96,20 @@ class SpectralConvBlock2D(nn.Module):
""" """
def __init__(self, input_numb_fields, output_numb_fields, n_modes): def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
"""
super().__init__() super().__init__()
# check type consistency # check type consistency
@@ -188,16 +203,19 @@ class SpectralConvBlock3D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes): def __init__(self, input_numb_fields, output_numb_fields, n_modes):
""" """
TODO PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
:param input_numb_fields: _description_ The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
:type input_numb_fields: _type_ and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param output_numb_fields: _description_
:type output_numb_fields: _type_ :param int input_numb_fields: The number of channels for the input.
:param n_modes: _description_ :param int output_numb_fields: The number of channels for the output.
:type n_modes: _type_ :param list | tuple n_modes: Number of modes to select for each dimension.
:raises ValueError: _description_ It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
:raises ValueError: _description_ and ``floor(Nz/2)+1``.
""" """
super().__init__() super().__init__()

View File

@@ -0,0 +1,43 @@
from pina.model.layers import FourierBlock1D, FourierBlock2D, FourierBlock3D
import torch
input_numb_fields = 3
output_numb_fields = 4
batch = 5
def test_constructor_1d():
FourierBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=5)
def test_forward_1d():
sconv = FourierBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=4)
x = torch.rand(batch, input_numb_fields, 10)
sconv(x)
def test_constructor_2d():
FourierBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4])
def test_forward_2d():
sconv = FourierBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4])
x = torch.rand(batch, input_numb_fields, 10, 10)
sconv(x)
def test_constructor_3d():
FourierBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4, 4])
def test_forward_3d():
sconv = FourierBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4, 4])
x = torch.rand(batch, input_numb_fields, 10, 10, 10)
sconv(x)