diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index bdbe70c..4156279 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -48,11 +48,13 @@ Models :maxdepth: 5 Network + KernelNeuralOperator FeedForward MultiFeedForward ResidualFeedForward DeepONet MIONet + FourierIntegralKernel FNO Layers diff --git a/docs/source/_rst/models/base_no.rst b/docs/source/_rst/models/base_no.rst new file mode 100644 index 0000000..772261c --- /dev/null +++ b/docs/source/_rst/models/base_no.rst @@ -0,0 +1,7 @@ +KernelNeuralOperator +======================= +.. currentmodule:: pina.model.base_no + +.. autoclass:: KernelNeuralOperator + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/models/fourier_kernel.rst b/docs/source/_rst/models/fourier_kernel.rst new file mode 100644 index 0000000..e45ba17 --- /dev/null +++ b/docs/source/_rst/models/fourier_kernel.rst @@ -0,0 +1,7 @@ +FourierIntegralKernel +========================= +.. currentmodule:: pina.model.fno + +.. autoclass:: FourierIntegralKernel + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index aab81bf..d2a6f08 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -1,13 +1,16 @@ __all__ = [ - "FeedForward", - "ResidualFeedForward", - "MultiFeedForward", - "DeepONet", - "MIONet", - "FNO", + 'FeedForward', + 'ResidualFeedForward', + 'MultiFeedForward', + 'DeepONet', + 'MIONet', + 'FNO', + 'FourierIntegralKernel', + 'KernelNeuralOperator' ] from .feed_forward import FeedForward, ResidualFeedForward from .multi_feed_forward import MultiFeedForward from .deeponet import DeepONet, MIONet -from .fno import FNO +from .fno import FNO, FourierIntegralKernel +from .base_no import KernelNeuralOperator diff --git a/pina/model/base_no.py b/pina/model/base_no.py new file mode 100644 index 0000000..5743437 --- /dev/null +++ b/pina/model/base_no.py @@ -0,0 +1,136 @@ +""" +Kernel Neural Operator Module. +""" + +import torch +from pina.utils import check_consistency + + +class KernelNeuralOperator(torch.nn.Module): + r""" + Base class for composing Neural Operators with integral kernels. + + This is a base class for composing neural operators with multiple + integral kernels. All neural operator models defined in PINA inherit + from this class. The structure is inspired by the work of Kovachki, N. + et al. see Figure 2 of the reference for extra details. The Neural + Operators inheriting from this class can be written as: + + .. math:: + G_\theta := P \circ K_m \circ \cdot \circ K_1 \circ L + + where: + + * :math:`G_\theta: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow + \mathcal{D}\subset \mathbb{R}^{\rm{out}}` is the neural operator + approximation of the unknown real operator :math:`G`, that is + :math:`G \approx G_\theta` + * :math:`L: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow + \mathbb{R}^{\rm{emb}}` is a lifting operator mapping the input + from its domain :math:`\mathcal{A}\subset \mathbb{R}^{\rm{in}}` + to its embedding dimension :math:`\mathbb{R}^{\rm{emb}}` + * :math:`\{K_i : \mathbb{R}^{\rm{emb}} \rightarrow + \mathbb{R}^{\rm{emb}} \}_{i=1}^m` are :math:`m` integral kernels + mapping each hidden representation to the next one. + * :math:`P : \mathbb{R}^{\rm{emb}} \rightarrow \mathcal{D}\subset + \mathbb{R}^{\rm{out}}` is a projection operator mapping the hidden + representation to the output function. + + .. seealso:: + + **Original reference**: Kovachki, N., Li, Z., Liu, B., + Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A. + (2023). *Neural operator: Learning maps between function + spaces with applications to PDEs*. Journal of Machine Learning + Research, 24(89), 1-97. + """ + def __init__(self, lifting_operator, integral_kernels, projection_operator): + """ + :param torch.nn.Module lifting_operator: The lifting operator + mapping the input to its hidden dimension. + :param torch.nn.Module integral_kernels: List of integral kernels + mapping each hidden representation to the next one. + :param torch.nn.Module projection_operator: The projection operator + mapping the hidden representation to the output function. + """ + + super().__init__() + + self._lifting_operator = lifting_operator + self._integral_kernels = integral_kernels + self._projection_operator = projection_operator + + @property + def lifting_operator(self): + """ + The lifting operator property. + """ + return self._lifting_operator + + @lifting_operator.setter + def lifting_operator(self, value): + """ + The lifting operator setter + + :param torch.nn.Module value: The lifting operator torch module. + """ + check_consistency(value, torch.nn.Module) + self._lifting_operator = value + + @property + def projection_operator(self): + """ + The projection operator property. + """ + return self._projection_operator + + @projection_operator.setter + def projection_operator(self, value): + """ + The projection operator setter + + :param torch.nn.Module value: The projection operator torch module. + """ + check_consistency(value, torch.nn.Module) + self._projection_operator = value + + @property + def integral_kernels(self): + """ + The integral kernels operator property. + """ + return self._integral_kernels + + @integral_kernels.setter + def integral_kernels(self, value): + """ + The integral kernels operator setter + + :param torch.nn.Module value: The integral kernels operator torch + module. + """ + check_consistency(value, torch.nn.Module) + self._integral_kernels = value + + def forward(self, x): + r""" + Forward computation for Base Neural Operator. It performs a + lifting of the input by the ``lifting_operator``. + Then different layers integral kernels are applied using + ``integral_kernels``. Finally the output is projected + to the final dimensionality by the ``projection_operator``. + + :param torch.Tensor x: The input tensor for performing the + computation. It expects a tensor :math:`B \times N \times D`, + where :math:`B` is the batch_size, :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem. In particular + :math:`D` is the number of spatial/paramtric/temporal variables + plus the field variables. For example for 2D problems with 2 + output\ variables :math:`D=4`. + :return: The output tensor obtained from the NO. + :rtype: torch.Tensor + """ + x = self.lifting_operator(x) + x = self.integral_kernels(x) + x = self.projection_operator(x) + return x diff --git a/pina/model/fno.py b/pina/model/fno.py index 93168e8..e320383 100644 --- a/pina/model/fno.py +++ b/pina/model/fno.py @@ -1,54 +1,66 @@ +""" +Fourier Neural Operator Module. +""" + import torch import torch.nn as nn -from ..utils import check_consistency -from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from pina import LabelTensor import warnings +from ..utils import check_consistency +from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D +from .base_no import KernelNeuralOperator -class FNO(torch.nn.Module): +class FourierIntegralKernel(torch.nn.Module): """ - The PINA implementation of Fourier Neural Operator network. + Implementation of Fourier Integral Kernel network. - Fourier Neural Operator (FNO) is a general architecture for learning Operators. - Unlike traditional machine learning methods FNO is designed to map - entire functions to other functions. It can be trained both with - Supervised learning strategies. FNO does global convolution by performing the - operation on the Fourier space. + This class implements the Fourier Integral Kernel network, which is a + PINA implementation of Fourier Neural Operator kernel network. + It performs global convolution by operating in the Fourier space. .. seealso:: - **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B., - Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). *Fourier neural operator for - parametric partial differential equations*. + **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, + K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. + (2020). *Fourier neural operator for parametric partial + differential equations*. DOI: `arXiv preprint arXiv:2010.08895. `_ """ - - def __init__( - self, - lifting_net, - projecting_net, - n_modes, - dimensions=3, - padding=8, - padding_type="constant", - inner_size=20, - n_layers=2, - func=nn.Tanh, - layers=None, - ): + def __init__(self, + input_numb_fields, + output_numb_fields, + n_modes, + dimensions=3, + padding=8, + padding_type="constant", + inner_size=20, + n_layers=2, + func=nn.Tanh, + layers=None): + """ + :param int input_numb_fields: Number of input fields. + :param int output_numb_fields: Number of output fields. + :param int | list[int] n_modes: Number of modes. + :param int dimensions: Number of dimensions (1, 2, or 3). + :param int padding: Padding size, defaults to 8. + :param str padding_type: Type of padding, defaults to "constant". + :param int inner_size: Inner size, defaults to 20. + :param int n_layers: Number of layers, defaults to 2. + :param torch.nn.Module func: Activation function, defaults to nn.Tanh. + :param list[int] layers: List of layer sizes, defaults to None. + """ super().__init__() # check type consistency - check_consistency(lifting_net, nn.Module) - check_consistency(projecting_net, nn.Module) check_consistency(dimensions, int) check_consistency(padding, int) check_consistency(padding_type, str) check_consistency(inner_size, int) check_consistency(n_layers, int) check_consistency(func, nn.Module, subclass=True) + if layers is not None: if isinstance(layers, (tuple, list)): check_consistency(layers, int) @@ -57,13 +69,9 @@ class FNO(torch.nn.Module): if not isinstance(n_modes, (list, tuple, int)): raise ValueError( "n_modes must be a int or list or tuple of valid modes." - " More information on the official documentation." - ) + " More information on the official documentation.") - # assign variables - # TODO check input lifting net and input projecting net - self._lifting_net = lifting_net - self._projecting_net = projecting_net + # assign padding self._padding = padding # initialize fourier layer for each dimension @@ -74,9 +82,11 @@ class FNO(torch.nn.Module): elif dimensions == 3: fourier_layer = FourierBlock3D else: - raise NotImplementedError("FNO implemented only for 1D/2D/3D data.") + raise NotImplementedError( + "FNO implemented only for 1D/2D/3D data." + ) - # Here we build the FNO by stacking Fourier Blocks + # Here we build the FNO kernels by stacking Fourier Blocks # 1. Assign output dimensions for each FNO layer if layers is None: @@ -86,43 +96,33 @@ class FNO(torch.nn.Module): if isinstance(func, list): if len(layers) != len(func): raise RuntimeError( - "Uncosistent number of layers and functions." - ) - self._functions = func + 'Uncosistent number of layers and functions.') + _functions = func else: - self._functions = [func for _ in range(len(layers))] + _functions = [func for _ in range(len(layers) - 1)] + _functions.append(torch.nn.Identity) # 3. Assign modes functions for each FNO layer if isinstance(n_modes, list): - if all(isinstance(i, list) for i in n_modes) and len(layers) != len( - n_modes - ): + if all(isinstance(i, list) + for i in n_modes) and len(layers) != len(n_modes): raise RuntimeError( - "Uncosistent number of layers and functions." - ) + "Uncosistent number of layers and functions.") elif all(isinstance(i, int) for i in n_modes): n_modes = [n_modes] * len(layers) else: n_modes = [n_modes] * len(layers) # 4. Build the FNO network - tmp_layers = layers.copy() - first_parameter = next(lifting_net.parameters()) - input_shape = first_parameter.size() - out_feats = lifting_net(torch.rand(size=input_shape)).shape[-1] - tmp_layers.insert(0, out_feats) - - self._layers = [] - for i in range(len(tmp_layers) - 1): - self._layers.append( - fourier_layer( - input_numb_fields=tmp_layers[i], - output_numb_fields=tmp_layers[i + 1], - n_modes=n_modes[i], - activation=self._functions[i], - ) - ) - self._layers = nn.Sequential(*self._layers) + _layers = [] + tmp_layers = [input_numb_fields] + layers + [output_numb_fields] + for i in range(len(layers)): + _layers.append( + fourier_layer(input_numb_fields=tmp_layers[i], + output_numb_fields=tmp_layers[i + 1], + n_modes=n_modes[i], + activation=_functions[i])) + self._layers = nn.Sequential(*_layers) # 5. Padding values for spectral conv if isinstance(padding, int): @@ -140,23 +140,22 @@ class FNO(torch.nn.Module): of Fourier Blocks are applied. Finally the output is projected to the final dimensionality by the ``projecting_net``. - :param torch.Tensor x: The input tensor for fourier block, depending on - ``dimension`` in the initialization. In particular it is expected + :param torch.Tensor x: The input tensor for fourier block, + depending on ``dimension`` in the initialization. + In particular it is expected: + * 1D tensors: ``[batch, X, channels]`` * 2D tensors: ``[batch, X, Y, channels]`` * 3D tensors: ``[batch, X, Y, Z, channels]`` - :return: The output tensor obtained from the FNO. + :return: The output tensor obtained from the kernels convolution. :rtype: torch.Tensor """ - if isinstance(x, LabelTensor): # TODO remove when Network is fixed + if isinstance(x, LabelTensor): #TODO remove when Network is fixed warnings.warn( - "LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor" + 'LabelTensor passed as input is not allowed,' + ' casting LabelTensor to Torch.Tensor' ) x = x.as_subclass(torch.Tensor) - - # lifting the input in higher dimensional space - x = self._lifting_net(x) - # permuting the input [batch, channels, x, y, ...] permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]] x = x.permute(permutation_idx) @@ -175,5 +174,85 @@ class FNO(torch.nn.Module): permutation_idx = [0, *[i for i in range(2, x.ndim)], 1] x = x.permute(permutation_idx) - # apply projecting operator and return - return self._projecting_net(x) + return x + + +class FNO(KernelNeuralOperator): + """ + The PINA implementation of Fourier Neural Operator network. + + Fourier Neural Operator (FNO) is a general architecture for + learning Operators. Unlike traditional machine learning methods + FNO is designed to map entire functions to other functions. It + can be trained with Supervised learning strategies. FNO does global + convolution by performing the operation on the Fourier space. + + .. seealso:: + + **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, + K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. + (2020). *Fourier neural operator for parametric partial + differential equations*. + DOI: `arXiv preprint arXiv:2010.08895. + `_ + """ + def __init__(self, + lifting_net, + projecting_net, + n_modes, + dimensions=3, + padding=8, + padding_type="constant", + inner_size=20, + n_layers=2, + func=nn.Tanh, + layers=None): + """ + :param torch.nn.Module lifting_net: The neural network for lifting + the input. + :param torch.nn.Module projecting_net: The neural network for + projecting the output. + :param int | list[int] n_modes: Number of modes. + :param int dimensions: Number of dimensions (1, 2, or 3). + :param int padding: Padding size, defaults to 8. + :param str padding_type: Type of padding, defaults to `constant`. + :param int inner_size: Inner size, defaults to 20. + :param int n_layers: Number of layers, defaults to 2. + :param torch.nn.Module func: Activation function, defaults to nn.Tanh. + :param list[int] layers: List of layer sizes, defaults to None. + """ + lifting_operator_out = lifting_net( + torch.rand(size=next(lifting_net.parameters()).size())).shape[-1] + super().__init__(lifting_operator=lifting_net, + projection_operator=projecting_net, + integral_kernels=FourierIntegralKernel( + input_numb_fields=lifting_operator_out, + output_numb_fields=next( + projecting_net.parameters()).size(), + n_modes=n_modes, + dimensions=dimensions, + padding=padding, + padding_type=padding_type, + inner_size=inner_size, + n_layers=n_layers, + func=func, + layers=layers)) + + def forward(self, x): + """ + Forward computation for Fourier Neural Operator. It performs a + lifting of the input by the ``lifting_net``. Then different layers + of Fourier Blocks are applied. Finally the output is projected + to the final dimensionality by the ``projecting_net``. + + :param torch.Tensor x: The input tensor for fourier block, + depending on ``dimension`` in the initialization. In + particular it is expected: + + * 1D tensors: ``[batch, X, channels]`` + * 2D tensors: ``[batch, X, Y, channels]`` + * 3D tensors: ``[batch, X, Y, Z, channels]`` + :return: The output tensor obtained from FNO. + :rtype: torch.Tensor + """ + return super().forward(x) diff --git a/tests/test_model/test_base_no.py b/tests/test_model/test_base_no.py new file mode 100644 index 0000000..4a14fd1 --- /dev/null +++ b/tests/test_model/test_base_no.py @@ -0,0 +1,40 @@ +import torch +from pina.model import KernelNeuralOperator, FeedForward + +input_dim = 2 +output_dim = 4 +embedding_dim = 24 +batch_size = 10 +numb = 256 +data = torch.rand(size=(batch_size, numb, input_dim), requires_grad=True) +output_shape = torch.Size([batch_size, numb, output_dim]) + + +lifting_operator = FeedForward(input_dimensions=input_dim, output_dimensions=embedding_dim) +projection_operator = FeedForward(input_dimensions=embedding_dim, output_dimensions=output_dim) +integral_kernels = torch.nn.Sequential(FeedForward(input_dimensions=embedding_dim, + output_dimensions=embedding_dim), + FeedForward(input_dimensions=embedding_dim, + output_dimensions=embedding_dim),) + +def test_constructor(): + KernelNeuralOperator(lifting_operator=lifting_operator, + integral_kernels=integral_kernels, + projection_operator=projection_operator) + +def test_forward(): + operator = KernelNeuralOperator(lifting_operator=lifting_operator, + integral_kernels=integral_kernels, + projection_operator=projection_operator) + out = operator(data) + assert out.shape == output_shape + +def test_backward(): + operator = KernelNeuralOperator(lifting_operator=lifting_operator, + integral_kernels=integral_kernels, + projection_operator=projection_operator) + out = operator(data) + loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out)) + loss.backward() + grad = data.grad + assert grad.shape == data.shape