From caa67ace93cc3d736d4693ab6e80f29a5edf3b53 Mon Sep 17 00:00:00 2001
From: Giovanni Canali
Date: Mon, 21 Jul 2025 17:10:31 +0200
Subject: [PATCH] add pirate network

---
 docs/source/_rst/_code.rst                     |   2 +
 .../_rst/model/block/pirate_network_block.rst  |   8 ++
 docs/source/_rst/model/pirate_network.rst      |   7 +
 pina/model/__init__.py                         |   2 +
 pina/model/block/__init__.py                   |   2 +
 pina/model/block/pirate_network_block.py       |  89 +++++++++++++
 pina/model/pirate_network.py                   | 118 +++++++++++++++++
 .../test_blocks/test_pirate_network_block.py   |  53 ++++++++
 tests/test_model/test_pirate_network.py        | 120 ++++++++++++++++++
 9 files changed, 401 insertions(+)
 create mode 100644 docs/source/_rst/model/block/pirate_network_block.rst
 create mode 100644 docs/source/_rst/model/pirate_network.rst
 create mode 100644 pina/model/block/pirate_network_block.py
 create mode 100644 pina/model/pirate_network.py
 create mode 100644 tests/test_blocks/test_pirate_network_block.py
 create mode 100644 tests/test_model/test_pirate_network.py

diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 3655379..9bd36ab 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -104,6 +104,7 @@ Models
    LowRankNeuralOperator
    GraphNeuralOperator
    GraphNeuralKernel
+   PirateNet

 Blocks
 -------------
@@ -121,6 +122,7 @@ Blocks
    Continuous Convolution Interface
    Continuous Convolution Block
    Orthogonal Block
+   PirateNet Block

 Message Passing
 -------------------
diff --git a/docs/source/_rst/model/block/pirate_network_block.rst b/docs/source/_rst/model/block/pirate_network_block.rst
new file mode 100644
index 0000000..5d0428a
--- /dev/null
+++ b/docs/source/_rst/model/block/pirate_network_block.rst
@@ -0,0 +1,8 @@
+PirateNet Block
+=======================================
+.. currentmodule:: pina.model.block.pirate_network_block
+
+.. autoclass:: PirateNetBlock
+   :members:
+   :show-inheritance:
+
diff --git a/docs/source/_rst/model/pirate_network.rst b/docs/source/_rst/model/pirate_network.rst
new file mode 100644
index 0000000..5b374c2
--- /dev/null
+++ b/docs/source/_rst/model/pirate_network.rst
@@ -0,0 +1,7 @@
+PirateNet
+=======================
+.. currentmodule:: pina.model.pirate_network
+
+.. autoclass:: PirateNet
+   :members:
+   :show-inheritance:
\ No newline at end of file
diff --git a/pina/model/__init__.py b/pina/model/__init__.py
index 606dde7..5e34048 100644
--- a/pina/model/__init__.py
+++ b/pina/model/__init__.py
@@ -13,6 +13,7 @@ __all__ = [
     "LowRankNeuralOperator",
     "Spline",
     "GraphNeuralOperator",
+    "PirateNet",
 ]

 from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@ from .average_neural_operator import AveragingNeuralOperator
 from .low_rank_neural_operator import LowRankNeuralOperator
 from .spline import Spline
 from .graph_neural_operator import GraphNeuralOperator
+from .pirate_network import PirateNet
diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py
index c40135b..08b3133 100644
--- a/pina/model/block/__init__.py
+++ b/pina/model/block/__init__.py
@@ -18,6 +18,7 @@ __all__ = [
     "LowRankBlock",
     "RBFBlock",
     "GNOBlock",
+    "PirateNetBlock",
 ]

 from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@ from .average_neural_operator_block import AVNOBlock
 from .low_rank_block import LowRankBlock
 from .rbf_block import RBFBlock
 from .gno_block import GNOBlock
+from .pirate_network_block import PirateNetBlock
diff --git a/pina/model/block/pirate_network_block.py b/pina/model/block/pirate_network_block.py
new file mode 100644
index 0000000..cfeb841
--- /dev/null
+++ b/pina/model/block/pirate_network_block.py
@@ -0,0 +1,89 @@
+"""Module for the PirateNet block class."""
+
+import torch
+from ...utils import check_consistency, check_positive_integer
+
+
+class PirateNetBlock(torch.nn.Module):
+    """
+    The inner block of the Physics-Informed residual adaptive network
+    (PirateNet).
+
+    The block consists of three dense layers with dual gating operations and
+    an adaptive residual connection. The trainable ``alpha`` parameter
+    controls the contribution of the residual connection.
+
+    .. seealso::
+
+        **Original reference**:
+        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
+        *Simulating Three-dimensional Turbulence with Physics-informed
+        Neural Networks*.
+        DOI: `arXiv preprint arXiv:2507.08972.
+        <https://arxiv.org/abs/2507.08972>`_
+    """
+
+    def __init__(self, inner_size, activation):
+        """
+        Initialization of the :class:`PirateNetBlock` class.
+
+        :param int inner_size: The number of hidden units in the dense layers.
+        :param torch.nn.Module activation: The activation function.
+        """
+        super().__init__()
+
+        # Check consistency
+        check_consistency(activation, torch.nn.Module, subclass=True)
+        check_positive_integer(inner_size, strict=True)
+
+        # Initialize the linear transformations of the dense layers
+        self.linear1 = torch.nn.Linear(inner_size, inner_size)
+        self.linear2 = torch.nn.Linear(inner_size, inner_size)
+        self.linear3 = torch.nn.Linear(inner_size, inner_size)
+
+        # Initialize the scales of the dense layers
+        self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
+        self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
+        self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))
+
+        # Initialize the adaptive residual connection parameter
+        self._alpha = torch.nn.Parameter(torch.zeros(1))
+
+        # Initialize the activation function
+        self.activation = activation()
+
+    def forward(self, x, U, V):
+        """
+        Forward pass of the PirateNet block. It computes the output of the
+        block by applying the dense layers with scaling, and combines the
+        result with the input through the adaptive residual connection.
+
+        :param x: The input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor U: The first shared gating tensor. It must have
+            the same shape as ``x``.
+        :param torch.Tensor V: The second shared gating tensor. It must have
+            the same shape as ``x``.
+        :return: The output tensor of the block.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        # Compute the output of the first dense layer with scaling
+        f = self.activation(self.linear1(x) * torch.exp(self.scale1))
+        z1 = f * U + (1 - f) * V
+
+        # Compute the output of the second dense layer with scaling
+        g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
+        z2 = g * U + (1 - g) * V
+
+        # Compute the output of the block
+        h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
+        return self._alpha * h + (1 - self._alpha) * x
+
+    @property
+    def alpha(self):
+        """
+        Return the alpha parameter.
+
+        :return: The alpha parameter controlling the residual connection.
+        :rtype: torch.nn.Parameter
+        """
+        return self._alpha
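A minimal sketch of how the block is meant to be driven (assumes the patch above is applied; all sizes here are illustrative). Because ``alpha`` is initialized to zero, a freshly constructed block acts as the identity map:

    import torch
    from pina.model.block import PirateNetBlock

    # Illustrative sizes: 20 samples in a 16-dimensional feature space.
    block = PirateNetBlock(inner_size=16, activation=torch.nn.Tanh)
    x = torch.rand(20, 16)

    # U and V are computed once from the embedding by the parent network
    # and shared across all blocks.
    U = torch.rand(20, 16)
    V = torch.rand(20, 16)

    out = block(x, U, V)
    assert out.shape == (20, 16)

    # alpha starts at zero, so the block is initially the identity map.
    assert block.alpha.item() == 0.0
    assert torch.allclose(out, x)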
diff --git a/pina/model/pirate_network.py b/pina/model/pirate_network.py
new file mode 100644
index 0000000..96102b4
--- /dev/null
+++ b/pina/model/pirate_network.py
@@ -0,0 +1,118 @@
+"""Module for the PirateNet model class."""
+
+import torch
+from .block import FourierFeatureEmbedding, PirateNetBlock
+from ..utils import check_consistency, check_positive_integer
+
+
+class PirateNet(torch.nn.Module):
+    """
+    Implementation of the Physics-Informed residual adaptive network
+    (PirateNet).
+
+    The model consists of a Fourier feature embedding layer, multiple
+    PirateNet blocks, and a final output layer. Each PirateNet block consists
+    of three dense layers with a dual gating mechanism and an adaptive
+    residual connection, whose contribution is controlled by a trainable
+    parameter ``alpha``.
+
+    In the original work, PirateNet is further augmented with random weight
+    factorization; here, the Fourier feature embedding and the adaptive
+    residual connections are used to mitigate spectral bias and ease the
+    training of deep networks.
+
+    .. seealso::
+
+        **Original reference**:
+        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
+        *Simulating Three-dimensional Turbulence with Physics-informed
+        Neural Networks*.
+        DOI: `arXiv preprint arXiv:2507.08972.
+        <https://arxiv.org/abs/2507.08972>`_
+    """
+
+    def __init__(
+        self,
+        input_dimension,
+        inner_size,
+        output_dimension,
+        embedding=None,
+        n_layers=3,
+        activation=torch.nn.Tanh,
+    ):
+        """
+        Initialization of the :class:`PirateNet` class.
+
+        :param int input_dimension: The number of input features.
+        :param int inner_size: The number of hidden units in the dense layers.
+        :param int output_dimension: The number of output features.
+        :param torch.nn.Module embedding: The embedding module used to
+            transform the input into a higher-dimensional feature space.
+            If ``None``, a default
+            :class:`~pina.model.block.FourierFeatureEmbedding` with
+            ``sigma=2.0`` is used. Default is ``None``.
+        :param int n_layers: The number of PirateNet blocks in the model.
+            Default is 3.
+        :param torch.nn.Module activation: The activation function to be used
+            in the blocks. Default is :class:`torch.nn.Tanh`.
+        """
+        super().__init__()
+
+        # Check consistency
+        check_consistency(activation, torch.nn.Module, subclass=True)
+        check_positive_integer(input_dimension, strict=True)
+        check_positive_integer(inner_size, strict=True)
+        check_positive_integer(output_dimension, strict=True)
+        check_positive_integer(n_layers, strict=True)
+
+        # Initialize the activation function
+        self.activation = activation()
+
+        # Initialize the Fourier embedding
+        self.embedding = embedding or FourierFeatureEmbedding(
+            input_dimension=input_dimension,
+            output_dimension=inner_size,
+            sigma=2.0,
+        )
+
+        # Initialize the shared dense layers
+        self.linear1 = torch.nn.Linear(inner_size, inner_size)
+        self.linear2 = torch.nn.Linear(inner_size, inner_size)
+
+        # Initialize the PirateNet blocks
+        self.blocks = torch.nn.ModuleList(
+            [PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
+        )
+
+        # Initialize the output layer
+        self.output_layer = torch.nn.Linear(inner_size, output_dimension)
+
+    def forward(self, input_):
+        """
+        Forward pass of the PirateNet model. It applies the Fourier feature
+        embedding, computes the shared gating tensors ``U`` and ``V``, and
+        passes the result through each block in the network. Finally, it
+        applies the output layer to produce the final output.
+
+        :param input_: The input tensor for the model.
+        :type input_: torch.Tensor | LabelTensor
+        :return: The output tensor of the model.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        # Apply the Fourier feature embedding
+        x = self.embedding(input_)
+
+        # Compute U and V from the shared dense layers
+        U = self.activation(self.linear1(x))
+        V = self.activation(self.linear2(x))
+
+        # Pass through each block in the network
+        for block in self.blocks:
+            x = block(x, U, V)
+
+        return self.output_layer(x)
+
+    @property
+    def alpha(self):
+        """
+        Return the alpha values of all PirateNetBlock layers.
+
+        :return: A list of alpha values, one from each block.
+        :rtype: list
+        """
+        return [block.alpha.item() for block in self.blocks]
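For completeness, a short usage sketch of the full model (assumes the patch is applied; the problem sizes are made up). It exercises the default Fourier-embedding path and the ``alpha`` property, which returns one value per block:

    import torch
    from pina.model import PirateNet

    # Made-up problem: 3 input features, 2 outputs, 64 hidden units.
    model = PirateNet(input_dimension=3, inner_size=64, output_dimension=2)

    x = torch.rand(20, 3)
    y = model(x)
    assert y.shape == (20, 2)

    # One alpha per block; all start at zero, so every block begins as an
    # identity map and depth is switched on gradually during training.
    print(model.alpha)  # [0.0, 0.0, 0.0]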
diff --git a/tests/test_blocks/test_pirate_network_block.py b/tests/test_blocks/test_pirate_network_block.py
new file mode 100644
index 0000000..b827d24
--- /dev/null
+++ b/tests/test_blocks/test_pirate_network_block.py
@@ -0,0 +1,53 @@
+import torch
+import pytest
+from pina.model.block import PirateNetBlock
+
+data = torch.rand((20, 3))
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+def test_constructor(inner_size):
+
+    PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
+
+    # Should fail if inner_size is negative
+    with pytest.raises(AssertionError):
+        PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+def test_forward(inner_size):
+
+    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
+
+    # Create a dummy embedding
+    dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
+    x = dummy_embedding(data)
+
+    # Create dummy U and V tensors
+    U = torch.rand((data.shape[0], inner_size))
+    V = torch.rand((data.shape[0], inner_size))
+
+    output_ = model(x, U, V)
+    assert output_.shape == (data.shape[0], inner_size)
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+def test_backward(inner_size):
+
+    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
+    data.requires_grad_()
+
+    # Create a dummy embedding
+    dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
+    x = dummy_embedding(data)
+
+    # Create dummy U and V tensors
+    U = torch.rand((data.shape[0], inner_size))
+    V = torch.rand((data.shape[0], inner_size))
+
+    output_ = model(x, U, V)
+
+    loss = torch.mean(output_)
+    loss.backward()
+    assert data.grad.shape == data.shape
diff --git a/tests/test_model/test_pirate_network.py b/tests/test_model/test_pirate_network.py
new file mode 100644
index 0000000..f552f81
--- /dev/null
+++ b/tests/test_model/test_pirate_network.py
@@ -0,0 +1,120 @@
+import torch
+import pytest
+from pina.model import PirateNet
+from pina.model.block import FourierFeatureEmbedding
+
+data = torch.rand((20, 3))
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+@pytest.mark.parametrize("n_layers", [1, 3])
+@pytest.mark.parametrize("output_dimension", [2, 4])
+def test_constructor(inner_size, n_layers, output_dimension):
+
+    # Loop over the default and custom embedding
+    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
+
+        # Constructor
+        model = PirateNet(
+            input_dimension=data.shape[1],
+            inner_size=inner_size,
+            output_dimension=output_dimension,
+            embedding=embedding,
+            n_layers=n_layers,
+            activation=torch.nn.Tanh,
+        )
+
+        # Check the default embedding
+        if embedding is None:
+            assert isinstance(model.embedding, FourierFeatureEmbedding)
+            assert model.embedding.sigma == 2.0
+
+        # Should fail if input_dimension is negative
+        with pytest.raises(AssertionError):
+            PirateNet(
+                input_dimension=-1,
+                inner_size=inner_size,
+                output_dimension=output_dimension,
+                embedding=embedding,
+                n_layers=n_layers,
+                activation=torch.nn.Tanh,
+            )
+
+        # Should fail if inner_size is negative
+        with pytest.raises(AssertionError):
+            PirateNet(
+                input_dimension=data.shape[1],
+                inner_size=-1,
+                output_dimension=output_dimension,
+                embedding=embedding,
+                n_layers=n_layers,
+                activation=torch.nn.Tanh,
+            )
+
+        # Should fail if output_dimension is negative
+        with pytest.raises(AssertionError):
+            PirateNet(
+                input_dimension=data.shape[1],
+                inner_size=inner_size,
+                output_dimension=-1,
+                embedding=embedding,
+                n_layers=n_layers,
+                activation=torch.nn.Tanh,
+            )
+
+        # Should fail if n_layers is negative
+        with pytest.raises(AssertionError):
+            PirateNet(
+                input_dimension=data.shape[1],
+                inner_size=inner_size,
+                output_dimension=output_dimension,
+                embedding=embedding,
+                n_layers=-1,
+                activation=torch.nn.Tanh,
+            )
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+@pytest.mark.parametrize("n_layers", [1, 3])
+@pytest.mark.parametrize("output_dimension", [2, 4])
+def test_forward(inner_size, n_layers, output_dimension):
+
+    # Loop over the default and custom embedding
+    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
+
+        model = PirateNet(
+            input_dimension=data.shape[1],
+            inner_size=inner_size,
+            output_dimension=output_dimension,
+            embedding=embedding,
+            n_layers=n_layers,
+            activation=torch.nn.Tanh,
+        )
+
+        output_ = model(data)
+        assert output_.shape == (data.shape[0], output_dimension)
+
+
+@pytest.mark.parametrize("inner_size", [10, 20])
+@pytest.mark.parametrize("n_layers", [1, 3])
+@pytest.mark.parametrize("output_dimension", [2, 4])
+def test_backward(inner_size, n_layers, output_dimension):
+
+    # Loop over the default and custom embedding
+    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
+
+        model = PirateNet(
+            input_dimension=data.shape[1],
+            inner_size=inner_size,
+            output_dimension=output_dimension,
+            embedding=embedding,
+            n_layers=n_layers,
+            activation=torch.nn.Tanh,
+        )
+
+        data.requires_grad_()
+        output_ = model(data)
+
+        loss = torch.mean(output_)
+        loss.backward()
+        assert data.grad.shape == data.shape
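Finally, a hedged end-to-end sketch in plain PyTorch (no PINA solver; the toy target and hyperparameters are invented for illustration), showing how ``model.alpha`` can be monitored to see how much of the network's depth is actually being used:

    import torch
    from pina.model import PirateNet

    # Toy regression target, purely for illustration.
    x = torch.rand(256, 2)
    y = torch.sin(x.sum(dim=1, keepdim=True))

    model = PirateNet(input_dimension=2, inner_size=32, output_dimension=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

    # The alpha values drift away from zero as the blocks learn to
    # contribute to the prediction.
    print(model.alpha)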