Fourier Block and minor fixes

* Adding fourier block 1d/2d/3d
* Adding docs to SpectralConvBlock1D/2D/3D and to FourierBlock1D/2D/3D
* Adding tests for fourier block
This commit is contained in:
Dario Coscia
2023-09-07 18:18:28 +02:00
committed by Nicola Demo
parent 2bf42d5fea
commit 83ecdb0eab
4 changed files with 228 additions and 24 deletions

View File

@@ -3,9 +3,13 @@ __all__ = [
'ResidualBlock',
'SpectralConvBlock1D',
'SpectralConvBlock2D',
'SpectralConvBlock3D'
'SpectralConvBlock3D',
'FourierBlock1D',
'FourierBlock2D',
'FourierBlock3D',
]
from .convolution_2d import ContinuousConvBlock
from .residual import ResidualBlock
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D

View File

@@ -2,9 +2,14 @@ import torch
import torch.nn as nn
from ...utils import check_consistency
from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
class FourierBlock(nn.Module):
"""Fourier block base class. Implementation of a fourier block.
class FourierBlock1D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso::
@@ -16,9 +21,143 @@ class FourierBlock(nn.Module):
"""
def __init__(self):
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
super().__init__()
"""
PINA implementation of Fourier block one dimension. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, N]``
and returns an output of size ``[batch, output_numb_fields, N]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(N/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
pass
return self._activation(self._spectral_conv(x) + self._linear(x))
class FourierBlock2D(nn.Module):
"""
Fourier block implementation for two dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso::
**Original reference**: Li, Zongyi, et al.
"Fourier neural operator for parametric partial
differential equations." arXiv preprint
arXiv:2010.08895 (2020)
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block two dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3])
return self._activation(self._spectral_conv(x) + ln)
class FourierBlock3D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
.. seealso::
**Original reference**: Li, Zongyi, et al.
"Fourier neural operator for parametric partial
differential equations." arXiv preprint
arXiv:2010.08895 (2020)
<https://arxiv.org/abs/2010.08895.pdf>`_.
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
"""
PINA implementation of Fourier block three dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
and ``floor(Nz/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# check type consistency
check_consistency(activation(), nn.Module)
# assign variables
self._spectral_conv = SpectralConvBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
shape_x = x.shape
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4])
return self._activation(self._spectral_conv(x) + ln)

View File

@@ -12,14 +12,18 @@ class SpectralConvBlock1D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
TODO
PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
:param input_numb_fields: _description_
:type input_numb_fields: _type_
:param output_numb_fields: _description_
:type output_numb_fields: _type_
:param n_modes: _description_
:type n_modes: _type_
The block expects an input of size ``[batch, input_numb_fields, N]``
and returns an output of size ``[batch, output_numb_fields, N]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param int n_modes: Number of modes to select, it must be at most equal
to the ``floor(N/2)+1``.
"""
super().__init__()
@@ -69,9 +73,6 @@ class SpectralConvBlock1D(nn.Module):
"""
batch_size = x.shape[0]
# if x.shape[-1] // 2 + 1 < self._modes:
# raise RuntimeError('Number of modes is too high, decrease number of modes.')
# Compute Fourier transform of the input
x_ft = torch.fft.rfft(x)
@@ -95,6 +96,20 @@ class SpectralConvBlock2D(nn.Module):
"""
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
"""
super().__init__()
# check type consistency
@@ -188,16 +203,19 @@ class SpectralConvBlock3D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
TODO
PINA implementation of spectral convolution. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space.
:param input_numb_fields: _description_
:type input_numb_fields: _type_
:param output_numb_fields: _description_
:type output_numb_fields: _type_
:param n_modes: _description_
:type n_modes: _type_
:raises ValueError: _description_
:raises ValueError: _description_
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
and ``floor(Nz/2)+1``.
"""
super().__init__()

View File

@@ -0,0 +1,43 @@
from pina.model.layers import FourierBlock1D, FourierBlock2D, FourierBlock3D
import torch
input_numb_fields = 3
output_numb_fields = 4
batch = 5
def test_constructor_1d():
FourierBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=5)
def test_forward_1d():
sconv = FourierBlock1D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=4)
x = torch.rand(batch, input_numb_fields, 10)
sconv(x)
def test_constructor_2d():
FourierBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4])
def test_forward_2d():
sconv = FourierBlock2D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4])
x = torch.rand(batch, input_numb_fields, 10, 10)
sconv(x)
def test_constructor_3d():
FourierBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4, 4])
def test_forward_3d():
sconv = FourierBlock3D(input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=[5, 4, 4])
x = torch.rand(batch, input_numb_fields, 10, 10, 10)
sconv(x)