add pirate network

This commit is contained in:
Giovanni Canali
2025-07-21 17:10:31 +02:00
committed by Giovanni Canali
parent 6d1d4ef423
commit caa67ace93
9 changed files with 401 additions and 0 deletions

View File

@@ -104,6 +104,7 @@ Models
LowRankNeuralOperator <model/low_rank_neural_operator.rst> LowRankNeuralOperator <model/low_rank_neural_operator.rst>
GraphNeuralOperator <model/graph_neural_operator.rst> GraphNeuralOperator <model/graph_neural_operator.rst>
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst> GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
PirateNet <model/pirate_network.rst>
Blocks Blocks
------------- -------------
@@ -121,6 +122,7 @@ Blocks
Continuous Convolution Interface <model/block/convolution_interface.rst> Continuous Convolution Interface <model/block/convolution_interface.rst>
Continuous Convolution Block <model/block/convolution.rst> Continuous Convolution Block <model/block/convolution.rst>
Orthogonal Block <model/block/orthogonal.rst> Orthogonal Block <model/block/orthogonal.rst>
PirateNet Block <model/block/pirate_network_block.rst>
Message Passing Message Passing
------------------- -------------------

View File

@@ -0,0 +1,8 @@
PirateNet Block
=======================================
.. currentmodule:: pina.model.block.pirate_network_block
.. autoclass:: PirateNetBlock
:members:
:show-inheritance:

View File

@@ -0,0 +1,7 @@
PirateNet
=======================
.. currentmodule:: pina.model.pirate_network
.. autoclass:: PirateNet
:members:
:show-inheritance:

View File

@@ -13,6 +13,7 @@ __all__ = [
"LowRankNeuralOperator", "LowRankNeuralOperator",
"Spline", "Spline",
"GraphNeuralOperator", "GraphNeuralOperator",
"PirateNet",
] ]
from .feed_forward import FeedForward, ResidualFeedForward from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@ from .average_neural_operator import AveragingNeuralOperator
from .low_rank_neural_operator import LowRankNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator
from .spline import Spline from .spline import Spline
from .graph_neural_operator import GraphNeuralOperator from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet

View File

@@ -18,6 +18,7 @@ __all__ = [
"LowRankBlock", "LowRankBlock",
"RBFBlock", "RBFBlock",
"GNOBlock", "GNOBlock",
"PirateNetBlock",
] ]
from .convolution_2d import ContinuousConvBlock from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@ from .average_neural_operator_block import AVNOBlock
from .low_rank_block import LowRankBlock from .low_rank_block import LowRankBlock
from .rbf_block import RBFBlock from .rbf_block import RBFBlock
from .gno_block import GNOBlock from .gno_block import GNOBlock
from .pirate_network_block import PirateNetBlock

View File

@@ -0,0 +1,89 @@
"""Module for the PirateNet block class."""
import torch
from ...utils import check_consistency, check_positive_integer
class PirateNetBlock(torch.nn.Module):
"""
The inner block of Physics-Informed residual adaptive network (PirateNet).
The block consists of three dense layers with dual gating operations and an
adaptive residual connection. The trainable ``alpha`` parameter controls
the contribution of the residual connection.
.. seealso::
**Original reference**:
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
*Simulating Three-dimensional Turbulence with Physics-informed Neural
Networks*.
DOI: `arXiv preprint arXiv:2507.08972.
<https://arxiv.org/abs/2507.08972>`_
"""
def __init__(self, inner_size, activation):
"""
Initialization of the :class:`PirateNetBlock` class.
:param int inner_size: The number of hidden units in the dense layers.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# Check consistency
check_consistency(activation, torch.nn.Module, subclass=True)
check_positive_integer(inner_size, strict=True)
# Initialize the linear transformations of the dense layers
self.linear1 = torch.nn.Linear(inner_size, inner_size)
self.linear2 = torch.nn.Linear(inner_size, inner_size)
self.linear3 = torch.nn.Linear(inner_size, inner_size)
# Initialize the scales of the dense layers
self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))
# Initialize the adaptive residual connection parameter
self._alpha = torch.nn.Parameter(torch.zeros(1))
# Initialize the activation function
self.activation = activation()
def forward(self, x, U, V):
"""
Forward pass of the PirateNet block. It computes the output of the block
by applying the dense layers with scaling, and combines the results with
the input using the adaptive residual connection.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor U: The first shared gating tensor. It must have the
same shape as ``x``.
:param torch.Tensor V: The second shared gating tensor. It must have the
same shape as ``x``.
:return: The output tensor of the block.
:rtype: torch.Tensor | LabelTensor
"""
# Compute the output of the first dense layer with scaling
f = self.activation(self.linear1(x) * torch.exp(self.scale1))
z1 = f * U + (1 - f) * V
# Compute the output of the second dense layer with scaling
g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
z2 = g * U + (1 - g) * V
# Compute the output of the block
h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
return self._alpha * h + (1 - self._alpha) * x
@property
def alpha(self):
"""
Return the alpha parameter.
:return: The alpha parameter controlling the residual connection.
:rtype: torch.nn.Parameter
"""
return self._alpha

View File

@@ -0,0 +1,118 @@
"""Module for the PirateNet model class."""
import torch
from .block import FourierFeatureEmbedding, PirateNetBlock
from ..utils import check_consistency, check_positive_integer
class PirateNet(torch.nn.Module):
"""
Implementation of Physics-Informed residual adaptive network (PirateNet).
The model consists of a Fourier feature embedding layer, multiple PirateNet
blocks, and a final output layer. Each PirateNet block consist of three
dense layers with dual gating mechanism and an adaptive residual connection,
whose contribution is controlled by a trainable parameter ``alpha``.
The PirateNet, augmented with random weight factorization, is designed to
mitigate spectral bias in deep networks.
.. seealso::
**Original reference**:
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
*Simulating Three-dimensional Turbulence with Physics-informed Neural
Networks*.
DOI: `arXiv preprint arXiv:2507.08972.
<https://arxiv.org/abs/2507.08972>`_
"""
def __init__(
self,
input_dimension,
inner_size,
output_dimension,
embedding=None,
n_layers=3,
activation=torch.nn.Tanh,
):
"""
Initialization of the :class:`PirateNet` class.
:param int input_dimension: The number of input features.
:param int inner_size: The number of hidden units in the dense layers.
:param int output_dimension: The number of output features.
:param torch.nn.Module embedding: The embedding module used to transform
the input into a higher-dimensional feature space. If ``None``, a
default :class:`~pina.model.block.FourierFeatureEmbedding` with
scaling factor of 2 is used. Default is ``None``.
:param int n_layers: The number of PirateNet blocks in the model.
Default is 3.
:param torch.nn.Module activation: The activation function to be used in
the blocks. Default is :class:`torch.nn.Tanh`.
"""
super().__init__()
# Check consistency
check_consistency(activation, torch.nn.Module, subclass=True)
check_positive_integer(input_dimension, strict=True)
check_positive_integer(inner_size, strict=True)
check_positive_integer(output_dimension, strict=True)
check_positive_integer(n_layers, strict=True)
# Initialize the activation function
self.activation = activation()
# Initialize the Fourier embedding
self.embedding = embedding or FourierFeatureEmbedding(
input_dimension=input_dimension,
output_dimension=inner_size,
sigma=2.0,
)
# Initialize the shared dense layers
self.linear1 = torch.nn.Linear(inner_size, inner_size)
self.linear2 = torch.nn.Linear(inner_size, inner_size)
# Initialize the PirateNet blocks
self.blocks = torch.nn.ModuleList(
[PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
)
# Initialize the output layer
self.output_layer = torch.nn.Linear(inner_size, output_dimension)
def forward(self, input_):
"""
Forward pass of the PirateNet model. It applies the Fourier feature
embedding, computes the shared gating tensors U and V, and passes the
input through each block in the network. Finally, it applies the output
layer to produce the final output.
:param input_: The input tensor for the model.
:type input_: torch.Tensor | LabelTensor
:return: The output tensor of the model.
:rtype: torch.Tensor | LabelTensor
"""
# Apply the Fourier feature embedding
x = self.embedding(input_)
# Compute U and V from the shared dense layers
U = self.activation(self.linear1(x))
V = self.activation(self.linear2(x))
# Pass through each block in the network
for block in self.blocks:
x = block(x, U, V)
return self.output_layer(x)
@property
def alpha(self):
"""
Return the alpha values of all PirateNetBlock layers.
:return: A list of alpha values from each block.
:rtype: list
"""
return [block.alpha.item() for block in self.blocks]

View File

@@ -0,0 +1,53 @@
import torch
import pytest
from pina.model.block import PirateNetBlock
data = torch.rand((20, 3))
@pytest.mark.parametrize("inner_size", [10, 20])
def test_constructor(inner_size):
PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
# Should fail if inner_size is negative
with pytest.raises(AssertionError):
PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
@pytest.mark.parametrize("inner_size", [10, 20])
def test_forward(inner_size):
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
# Create dummy embedding
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
x = dummy_embedding(data)
# Create dummy U and V tensors
U = torch.rand((data.shape[0], inner_size))
V = torch.rand((data.shape[0], inner_size))
output_ = model(x, U, V)
assert output_.shape == (data.shape[0], inner_size)
@pytest.mark.parametrize("inner_size", [10, 20])
def test_backward(inner_size):
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
data.requires_grad_()
# Create dummy embedding
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
x = dummy_embedding(data)
# Create dummy U and V tensors
U = torch.rand((data.shape[0], inner_size))
V = torch.rand((data.shape[0], inner_size))
output_ = model(x, U, V)
loss = torch.mean(output_)
loss.backward()
assert data.grad.shape == data.shape

View File

@@ -0,0 +1,120 @@
import torch
import pytest
from pina.model import PirateNet
from pina.model.block import FourierFeatureEmbedding
data = torch.rand((20, 3))
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_constructor(inner_size, n_layers, output_dimension):
# Loop over the default and custom embedding
for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
# Constructor
model = PirateNet(
input_dimension=data.shape[1],
inner_size=inner_size,
output_dimension=output_dimension,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
# Check the default embedding
if embedding is None:
assert isinstance(model.embedding, FourierFeatureEmbedding)
assert model.embedding.sigma == 2.0
# Should fail if input_dimension is negative
with pytest.raises(AssertionError):
PirateNet(
input_dimension=-1,
inner_size=inner_size,
output_dimension=output_dimension,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
# Should fail if inner_size is negative
with pytest.raises(AssertionError):
PirateNet(
input_dimension=data.shape[1],
inner_size=-1,
output_dimension=output_dimension,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
# Should fail if output_dimension is negative
with pytest.raises(AssertionError):
PirateNet(
input_dimension=data.shape[1],
inner_size=inner_size,
output_dimension=-1,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
# Should fail if n_layers is negative
with pytest.raises(AssertionError):
PirateNet(
input_dimension=data.shape[1],
inner_size=inner_size,
output_dimension=output_dimension,
embedding=embedding,
n_layers=-1,
activation=torch.nn.Tanh,
)
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_forward(inner_size, n_layers, output_dimension):
# Loop over the default and custom embedding
for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
model = PirateNet(
input_dimension=data.shape[1],
inner_size=inner_size,
output_dimension=output_dimension,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
output_ = model(data)
assert output_.shape == (data.shape[0], output_dimension)
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_backward(inner_size, n_layers, output_dimension):
# Loop over the default and custom embedding
for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:
model = PirateNet(
input_dimension=data.shape[1],
inner_size=inner_size,
output_dimension=output_dimension,
embedding=embedding,
n_layers=n_layers,
activation=torch.nn.Tanh,
)
data.requires_grad_()
output_ = model(data)
loss = torch.mean(output_)
loss.backward()
assert data.grad.shape == data.shape