diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 6a9b80c..74c7852 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -57,6 +57,7 @@ Models
     FourierIntegralKernel
     FNO
     AveragingNeuralOperator
+    LowRankNeuralOperator
 
 Layers
 -------------
@@ -69,6 +70,7 @@ Layers
     Spectral convolution
     Fourier layers
     Averaging layer
+    Low Rank layer
     Continuous convolution
     Proper Orthogonal Decomposition
     Periodic Boundary Condition embeddings
diff --git a/docs/source/_rst/layers/lowrank_layer.rst b/docs/source/_rst/layers/lowrank_layer.rst
new file mode 100644
index 0000000..6e72feb
--- /dev/null
+++ b/docs/source/_rst/layers/lowrank_layer.rst
@@ -0,0 +1,8 @@
+Low Rank layer
+====================
+.. currentmodule:: pina.model.layers.lowrank_layer
+
+.. autoclass:: LowRankBlock
+   :members:
+   :show-inheritance:
+   :noindex:
diff --git a/docs/source/_rst/models/lno.rst b/docs/source/_rst/models/lno.rst
new file mode 100644
index 0000000..f3f8277
--- /dev/null
+++ b/docs/source/_rst/models/lno.rst
@@ -0,0 +1,7 @@
+Low Rank Neural Operator
+==============================
+.. currentmodule:: pina.model.lno
+
+.. autoclass:: LowRankNeuralOperator
+   :members:
+   :show-inheritance:
diff --git a/pina/model/__init__.py b/pina/model/__init__.py
index b084988..4c441ee 100644
--- a/pina/model/__init__.py
+++ b/pina/model/__init__.py
@@ -8,6 +8,7 @@ __all__ = [
     "FourierIntegralKernel",
     "KernelNeuralOperator",
     "AveragingNeuralOperator",
+    "LowRankNeuralOperator",
 ]
 
 from .feed_forward import FeedForward, ResidualFeedForward
@@ -16,3 +17,4 @@ from .deeponet import DeepONet, MIONet
 from .fno import FNO, FourierIntegralKernel
 from .base_no import KernelNeuralOperator
 from .avno import AveragingNeuralOperator
+from .lno import LowRankNeuralOperator
diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py
index 87ea4e1..0d3e8a7 100644
--- a/pina/model/layers/__init__.py
+++ b/pina/model/layers/__init__.py
@@ -11,6 +11,7 @@ __all__ = [
     "PODBlock",
     "PeriodicBoundaryEmbedding",
     "AVNOBlock",
+    "LowRankBlock",
     "AdaptiveActivationFunction",
 ]
 
@@ -25,4 +26,5 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
 from .pod import PODBlock
 from .embedding import PeriodicBoundaryEmbedding
 from .avno_layer import AVNOBlock
+from .lowrank_layer import LowRankBlock
 from .adaptive_func import AdaptiveActivationFunction
diff --git a/pina/model/layers/lowrank_layer.py b/pina/model/layers/lowrank_layer.py
new file mode 100644
index 0000000..2a26e69
--- /dev/null
+++ b/pina/model/layers/lowrank_layer.py
@@ -0,0 +1,135 @@
+""" Module for the Low Rank Neural Operator Layer class. """
+
+import torch
+
+from pina.utils import check_consistency
+import pina.model as pm  # avoid circular import
+
+
+class LowRankBlock(torch.nn.Module):
+    r"""
+    The PINA implementation of the inner layer of the Low Rank Neural
+    Operator.
+
+    The operator layer performs an affine transformation where the
+    convolution is approximated with a low rank decomposition. Given the
+    input function :math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
+    the operator update :math:`K(v)` as:
+
+    .. math::
+        K(v) = \sigma\left(Wv(x) + b + \sum_{i=1}^r \langle
+        \psi^{(i)} , v(x) \rangle \phi^{(i)} \right)
+
+    where:
+
+    * :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
+      corresponding to the ``embedding_dimenion`` object
+    * :math:`\sigma` is a non-linear activation, corresponding to the
+      ``func`` object
+    * :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
+    * :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
+    * :math:`\psi^{(i)}\in\mathbb{R}^{\rm{emb}}` and
+      :math:`\phi^{(i)}\in\mathbb{R}^{\rm{emb}}` are the :math:`r` low rank
+      learnable basis functions, parametrized by a feed forward network.
+
+    .. seealso::
+
+        **Original reference**: Kovachki, N., Li, Z., Liu, B.,
+        Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
+        (2023). *Neural operator: Learning maps between function
+        spaces with applications to PDEs*. Journal of Machine Learning
+        Research, 24(89), 1-97.
+
+    """
+
+    def __init__(self,
+                 input_dimensions,
+                 embedding_dimenion,
+                 rank,
+                 inner_size=20,
+                 n_layers=2,
+                 func=torch.nn.Tanh,
+                 bias=True):
+        r"""
+        :param int input_dimensions: The number of input components of the
+            model. Expected tensor shape of the form :math:`(*, d)`, where *
+            means any number of dimensions including none,
+            and :math:`d` the ``input_dimensions``.
+        :param int embedding_dimenion: Size of the embedding dimension of the
+            field.
+        :param int rank: The rank :math:`r` of the low rank approximation,
+            i.e. the number of basis functions. The basis network outputs
+            ``2 * rank`` basis functions of size ``embedding_dimenion``,
+            one set for :math:`\psi` and one for :math:`\phi`.
+        :param int inner_size: Number of neurons in the hidden layer(s) of the
+            basis function network. Default is 20.
+        :param int n_layers: Number of hidden layers of the
+            basis function network. Default is 2.
+        :param func: The activation function to use in the
+            basis function network. If a single
+            :class:`torch.nn.Module` is passed, this is used as
+            activation function after each layer, except the last one.
+            If a list of Modules is passed,
+            they are used as activation functions at each layer, in order.
+        :param bool bias: If ``True`` the MLP of the basis function network
+            will include a bias term.
+        """
+        super().__init__()
+
+        # check the rank consistency before it is used to build the basis
+        check_consistency(rank, int)
+        self._rank = rank
+
+        # network computing the psi and phi basis functions
+        # (further consistency checks happen inside FeedForward)
+        self._basis = pm.FeedForward(
+            input_dimensions=input_dimensions,
+            output_dimensions=2 * rank * embedding_dimenion,
+            inner_size=inner_size,
+            n_layers=n_layers,
+            func=func,
+            bias=bias)
+        # affine transformation of the field
+        self._nn = torch.nn.Linear(embedding_dimenion, embedding_dimenion)
+        self._func = func()
+
+    def forward(self, x, coords):
+        r"""
+        Forward pass of the layer. It performs an affine transformation of
+        the field and a low rank approximation, computed by taking the dot
+        product of the basis :math:`\psi^{(i)}` with the field vector
+        :math:`v`, and using the resulting coefficients to expand the basis
+        :math:`\phi^{(i)}`, evaluated at the spatial input :math:`x`.
+
+        :param torch.Tensor x: The input tensor for performing the
+            computation. It expects a tensor :math:`B \times N \times D`,
+            where :math:`B` is the batch_size, :math:`N` the number of points
+            in the mesh, :math:`D` the dimension of the problem. In particular
+            :math:`D` is the codomain of the function :math:`v`. For example
+            a scalar function has :math:`D=1`, a 4-dimensional vector function
+            :math:`D=4`.
+        :param torch.Tensor coords: The coordinates at which the field is
+            evaluated for performing the computation. It expects a
+            tensor :math:`B \times N \times d`,
+            where :math:`B` is the batch_size, :math:`N` the number of points
+            in the mesh, :math:`d` the dimension of the domain.
+        :return: The output tensor obtained from the Low Rank Neural
+            Operator Block.
+        :rtype: torch.Tensor
+        """
+        # compute the basis functions at the given coordinates
+        basis = self._basis(coords)
+        # reshape [B, N, D, 2*rank]
+        shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]
+        basis = basis.reshape(shape)
+        # split into the psi and phi basis functions
+        psi = basis[..., :self.rank]
+        phi = basis[..., self.rank:]
+        # project the field onto the psi basis (dot product)
+        coeff = torch.einsum('...dr,...d->...r', psi, x)
+        # expand the coefficients on the phi basis
+        expansion = torch.einsum('...r,...dr->...d', coeff, phi)
+        # apply the affine transformation and the activation, then return
+        return self._func(self._nn(x) + expansion)
+
+    @property
+    def rank(self):
+        """
+        The basis rank.
+        """
+        return self._rank
diff --git a/pina/model/lno.py b/pina/model/lno.py
new file mode 100644
index 0000000..3612704
--- /dev/null
+++ b/pina/model/lno.py
@@ -0,0 +1,143 @@
+"""Module for the Low Rank Neural Operator."""
+
+import torch
+from torch import nn, concatenate
+
+from pina.utils import check_consistency
+
+from .base_no import KernelNeuralOperator
+from .layers.lowrank_layer import LowRankBlock
+
+
+class LowRankNeuralOperator(KernelNeuralOperator):
+    """
+    Implementation of the Low Rank Neural Operator.
+
+    The Low Rank Neural Operator is a general architecture for
+    learning operators. Unlike traditional machine learning methods,
+    LowRankNeuralOperator is designed to map entire functions
+    to other functions. It can be trained with supervised or PINN-based
+    learning strategies.
+    LowRankNeuralOperator performs convolution via a low rank
+    approximation, see :class:`~pina.model.layers.lowrank_layer.LowRankBlock`.
+
+    .. seealso::
+
+        **Original reference**: Kovachki, N., Li, Z., Liu, B.,
+        Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
+        (2023). *Neural operator: Learning maps between function
+        spaces with applications to PDEs*. Journal of Machine Learning
+        Research, 24(89), 1-97.
+    """
+
+    def __init__(
+        self,
+        lifting_net,
+        projecting_net,
+        field_indices,
+        coordinates_indices,
+        n_kernel_layers,
+        rank,
+        inner_size=20,
+        n_layers=2,
+        func=torch.nn.Tanh,
+        bias=True
+    ):
+        """
+        :param torch.nn.Module lifting_net: The neural network for lifting
+            the input. It must take as input the input field and the
+            coordinates at which the input field is evaluated. The output of
+            the lifting net is chosen as embedding dimension of the problem.
+        :param torch.nn.Module projecting_net: The neural network for
+            projecting the output. It must take as input the embedding
+            dimension (output of the ``lifting_net``) plus the dimension
+            of the coordinates.
+        :param list[str] field_indices: the labels of the fields
+            in the input tensor.
+        :param list[str] coordinates_indices: the labels of the
+            coordinates in the input tensor.
+        :param int n_kernel_layers: number of hidden kernel layers, i.e. of
+            stacked :class:`~pina.model.layers.lowrank_layer.LowRankBlock`
+            blocks.
+        :param int rank: the rank of the low rank approximation used in each
+            kernel layer.
+        :param int inner_size: Number of neurons in the hidden layer(s) of the
+            basis function network. Default is 20.
+        :param int n_layers: Number of hidden layers of the
+            basis function network. Default is 2.
+        :param func: The activation function to use in the
+            basis function network. If a single
+            :class:`torch.nn.Module` is passed, this is used as
+            activation function after each layer, except the last one.
+            If a list of Modules is passed,
+            they are used as activation functions at each layer, in order.
+        :param bool bias: If ``True`` the MLP of the basis function network
+            will include a bias term.
+        """
+
+        # check consistency
+        check_consistency(field_indices, str)
+        check_consistency(coordinates_indices, str)
+        check_consistency(n_kernel_layers, int)
+
+        # check hidden dimensions match
+        input_lifting_net = next(lifting_net.parameters()).size()[-1]
+        output_lifting_net = lifting_net(
+            torch.rand(size=next(lifting_net.parameters()).size())
+        ).shape[-1]
+        projecting_net_input = next(projecting_net.parameters()).size()[-1]
+
+        if len(field_indices) + len(coordinates_indices) != input_lifting_net:
+            raise ValueError(
+                "The lifting_net must take as input the "
+                "coordinates vector and the field vector."
+            )
+
+        if (
+            output_lifting_net + len(coordinates_indices)
+            != projecting_net_input
+        ):
+            raise ValueError(
+                "The projecting_net input must be equal to "
+                "the embedding dimension (which is the output) "
+                "of the lifting_net plus the dimension of the "
+                "coordinates, i.e. len(coordinates_indices)."
+            )
+
+        # assign
+        self.coordinates_indices = coordinates_indices
+        self.field_indices = field_indices
+        integral_net = nn.Sequential(
+            *[LowRankBlock(input_dimensions=len(coordinates_indices),
+                           embedding_dimenion=output_lifting_net,
+                           rank=rank,
+                           inner_size=inner_size,
+                           n_layers=n_layers,
+                           func=func,
+                           bias=bias) for _ in range(n_kernel_layers)]
+        )
+        super().__init__(lifting_net, integral_net, projecting_net)
+
+    def forward(self, x):
+        r"""
+        Forward computation for the Low Rank Neural Operator. It performs a
+        lifting of the input by the ``lifting_net``. Then several layers
+        of Low Rank Neural Operator Blocks are applied.
+        Finally the output is projected to the final dimensionality
+        by the ``projecting_net``.
+
+        :param torch.Tensor x: The input tensor for the Low Rank Neural
+            Operator. It expects
+            a tensor :math:`B \times N \times D`,
+            where :math:`B` is the batch_size, :math:`N` the number of points
+            in the mesh, :math:`D` the dimension of the problem, i.e. the sum
+            ``len(coordinates_indices) + len(field_indices)``.
+        :return: The output tensor obtained from the Low Rank Neural Operator.
+        :rtype: torch.Tensor
+        """
+        # extract the spatial coordinates
+        coords = x.extract(self.coordinates_indices)
+        # lifting
+        x = self._lifting_operator(x)
+        # apply the low rank kernel layers
+        for module in self._integral_kernels:
+            x = module(x, coords)
+        # concatenate with the coordinates and project
+        return self._projection_operator(concatenate((x, coords), dim=-1))
diff --git a/tests/test_layers/test_lnolayer.py b/tests/test_layers/test_lnolayer.py
new file mode 100644
index 0000000..28db849
--- /dev/null
+++ b/tests/test_layers/test_lnolayer.py
@@ -0,0 +1,58 @@
+import torch
+import pytest
+
+from pina.model.layers import LowRankBlock
+from pina import LabelTensor
+
+
+input_dimensions = 2
+embedding_dimenion = 1
+rank = 4
+inner_size = 20
+n_layers = 2
+func = torch.nn.Tanh
+bias = True
+
+
+def test_constructor():
+    LowRankBlock(input_dimensions=input_dimensions,
+                 embedding_dimenion=embedding_dimenion,
+                 rank=rank,
+                 inner_size=inner_size,
+                 n_layers=n_layers,
+                 func=func,
+                 bias=bias)
+
+
+def test_constructor_wrong():
+    # a non-integer rank must be rejected
+    with pytest.raises(ValueError):
+        LowRankBlock(input_dimensions=input_dimensions,
+                     embedding_dimenion=embedding_dimenion,
+                     rank=0.5,
+                     inner_size=inner_size,
+                     n_layers=n_layers,
+                     func=func,
+                     bias=bias)
+
+
+def test_forward():
+    block = LowRankBlock(input_dimensions=input_dimensions,
+                         embedding_dimenion=embedding_dimenion,
+                         rank=rank,
+                         inner_size=inner_size,
+                         n_layers=n_layers,
+                         func=func,
+                         bias=bias)
+    data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
+    block(data.extract('u'), data.extract(['x', 'y']))
+
+
+def test_backward():
+    block = LowRankBlock(input_dimensions=input_dimensions,
+                         embedding_dimenion=embedding_dimenion,
+                         rank=rank,
+                         inner_size=inner_size,
+                         n_layers=n_layers,
+                         func=func,
+                         bias=bias)
+    data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
+    data.requires_grad_(True)
+    out = block(data.extract('u'), data.extract(['x', 'y']))
+    loss = out.mean()
+    loss.backward()
diff --git a/tests/test_model/test_lno.py b/tests/test_model/test_lno.py
new file mode 100644
index 0000000..1cd09a7
--- /dev/null
+++ b/tests/test_model/test_lno.py
@@ -0,0 +1,141 @@
+import torch
+from pina.model import LowRankNeuralOperator
+from pina import LabelTensor
+import pytest
+
+
+batch_size = 15
+n_layers = 4
+embedding_dim = 24
+func = torch.nn.Tanh
+rank = 4
+n_kernel_layers = 3
+field_indices = ['u']
+coordinates_indices = ['x', 'y']
+
+
+def test_constructor():
+    # working constructor
+    lifting_net = torch.nn.Linear(
+        len(coordinates_indices) + len(field_indices), embedding_dim)
+    projecting_net = torch.nn.Linear(
+        embedding_dim + len(coordinates_indices), len(field_indices))
+    LowRankNeuralOperator(
+        lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        coordinates_indices=coordinates_indices,
+        field_indices=field_indices,
+        n_kernel_layers=n_kernel_layers,
+        rank=rank)
+
+    # not working constructors
+    with pytest.raises(ValueError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=projecting_net,
+            coordinates_indices=coordinates_indices,
+            field_indices=field_indices,
+            n_kernel_layers=3.2,  # wrong
+            rank=rank)
+
+    with pytest.raises(AttributeError):
+        LowRankNeuralOperator(
+            lifting_net=[0],  # wrong, not a torch.nn.Module
+            projecting_net=projecting_net,
+            coordinates_indices=coordinates_indices,
+            field_indices=field_indices,
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+    with pytest.raises(AttributeError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=[0],  # wrong, not a torch.nn.Module
+            coordinates_indices=coordinates_indices,
+            field_indices=field_indices,
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+    with pytest.raises(ValueError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=projecting_net,
+            coordinates_indices=[0],  # wrong
+            field_indices=field_indices,
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+    with pytest.raises(ValueError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=projecting_net,
+            coordinates_indices=coordinates_indices,
+            field_indices=[0],  # wrong
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+    # lifting_net with the wrong input dimension
+    lifting_net = torch.nn.Linear(len(coordinates_indices), embedding_dim)
+    with pytest.raises(ValueError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=projecting_net,
+            coordinates_indices=coordinates_indices,
+            field_indices=field_indices,
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+    # projecting_net with the wrong input dimension
+    lifting_net = torch.nn.Linear(
+        len(coordinates_indices) + len(field_indices), embedding_dim)
+    projecting_net = torch.nn.Linear(embedding_dim, len(field_indices))
+    with pytest.raises(ValueError):
+        LowRankNeuralOperator(
+            lifting_net=lifting_net,
+            projecting_net=projecting_net,
+            coordinates_indices=coordinates_indices,
+            field_indices=field_indices,
+            n_kernel_layers=n_kernel_layers,
+            rank=rank)
+
+
+def test_forward():
+    lifting_net = torch.nn.Linear(
+        len(coordinates_indices) + len(field_indices), embedding_dim)
+    projecting_net = torch.nn.Linear(
+        embedding_dim + len(coordinates_indices), len(field_indices))
+    lno = LowRankNeuralOperator(
+        lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        coordinates_indices=coordinates_indices,
+        field_indices=field_indices,
+        n_kernel_layers=n_kernel_layers,
+        rank=rank)
+
+    input_ = LabelTensor(
+        torch.rand(batch_size, 100,
+                   len(coordinates_indices) + len(field_indices)),
+        coordinates_indices + field_indices)
+
+    out = lno(input_)
+    assert out.shape == torch.Size(
+        [batch_size, input_.shape[1], len(field_indices)])
+
+
+def test_backward():
+    lifting_net = torch.nn.Linear(
+        len(coordinates_indices) + len(field_indices), embedding_dim)
+    projecting_net = torch.nn.Linear(
+        embedding_dim + len(coordinates_indices), len(field_indices))
+    lno = LowRankNeuralOperator(
+        lifting_net=lifting_net,
+        projecting_net=projecting_net,
+        coordinates_indices=coordinates_indices,
+        field_indices=field_indices,
+        n_kernel_layers=n_kernel_layers,
+        rank=rank)
+    input_ = LabelTensor(
+        torch.rand(batch_size, 100,
+                   len(coordinates_indices) + len(field_indices)),
+        coordinates_indices + field_indices)
+    input_ = input_.requires_grad_()
+    out = lno(input_)
+    tmp = torch.linalg.norm(out)
+    tmp.backward()
+    grad = input_.grad
+    assert grad.shape == torch.Size(
+        [batch_size, input_.shape[1],
+         len(coordinates_indices) + len(field_indices)])
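
For reference, a minimal usage sketch of the new operator (not part of the diff). It mirrors the API and tensor shapes exercised in tests/test_model/test_lno.py; the embedding size, number of kernel layers, rank, and the 15-batch / 100-point mesh are illustrative choices, not library defaults.

import torch
from pina import LabelTensor
from pina.model import LowRankNeuralOperator

coordinates_indices = ['x', 'y']   # labels of the spatial coordinates
field_indices = ['u']              # labels of the input field
embedding_dim = 24                 # illustrative embedding size

# lifting: (coordinates + field) -> embedding
lifting_net = torch.nn.Linear(
    len(coordinates_indices) + len(field_indices), embedding_dim)
# projection: (embedding + coordinates) -> field
projecting_net = torch.nn.Linear(
    embedding_dim + len(coordinates_indices), len(field_indices))

lno = LowRankNeuralOperator(lifting_net=lifting_net,
                            projecting_net=projecting_net,
                            field_indices=field_indices,
                            coordinates_indices=coordinates_indices,
                            n_kernel_layers=3,
                            rank=4)

# batch of 15 input functions, each sampled on a mesh of 100 points
input_ = LabelTensor(torch.rand(15, 100, 3),
                     coordinates_indices + field_indices)
output = lno(input_)  # shape: [15, 100, len(field_indices)]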