add pirate network
Committed by Giovanni Canali
Parent: 6d1d4ef423
Commit: caa67ace93
@@ -104,6 +104,7 @@ Models
    LowRankNeuralOperator <model/low_rank_neural_operator.rst>
    GraphNeuralOperator <model/graph_neural_operator.rst>
    GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
    PirateNet <model/pirate_network.rst>

Blocks
-------------
@@ -121,6 +122,7 @@ Blocks
    Continuous Convolution Interface <model/block/convolution_interface.rst>
    Continuous Convolution Block <model/block/convolution.rst>
    Orthogonal Block <model/block/orthogonal.rst>
    PirateNet Block <model/block/pirate_network_block.rst>

Message Passing
-------------------
docs/source/_rst/model/block/pirate_network_block.rst (new file, 8 lines)
@@ -0,0 +1,8 @@
PirateNet Block
=======================================
.. currentmodule:: pina.model.block.pirate_network_block

.. autoclass:: PirateNetBlock
   :members:
   :show-inheritance:
docs/source/_rst/model/pirate_network.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
PirateNet
=======================
.. currentmodule:: pina.model.pirate_network

.. autoclass:: PirateNet
   :members:
   :show-inheritance:
pina/model/__init__.py (modified)

@@ -13,6 +13,7 @@ __all__ = [
    "LowRankNeuralOperator",
    "Spline",
    "GraphNeuralOperator",
    "PirateNet",
]

from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@ from .average_neural_operator import AveragingNeuralOperator
from .low_rank_neural_operator import LowRankNeuralOperator
from .spline import Spline
from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet
pina/model/block/__init__.py (modified)

@@ -18,6 +18,7 @@ __all__ = [
    "LowRankBlock",
    "RBFBlock",
    "GNOBlock",
    "PirateNetBlock",
]

from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@ from .average_neural_operator_block import AVNOBlock
from .low_rank_block import LowRankBlock
from .rbf_block import RBFBlock
from .gno_block import GNOBlock
from .pirate_network_block import PirateNetBlock
pina/model/block/pirate_network_block.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""Module for the PirateNet block class."""

import torch
from ...utils import check_consistency, check_positive_integer


class PirateNetBlock(torch.nn.Module):
    """
    The inner block of the Physics-Informed residual adaptive network
    (PirateNet).

    The block consists of three dense layers with dual gating operations and an
    adaptive residual connection. The trainable ``alpha`` parameter controls
    the contribution of the residual connection.

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed Neural
        Networks*.
        DOI: `arXiv preprint arXiv:2507.08972.
        <https://arxiv.org/abs/2507.08972>`_
    """

    def __init__(self, inner_size, activation):
        """
        Initialization of the :class:`PirateNetBlock` class.

        :param int inner_size: The number of hidden units in the dense layers.
        :param torch.nn.Module activation: The activation function.
        """
        super().__init__()

        # Check consistency
        check_consistency(activation, torch.nn.Module, subclass=True)
        check_positive_integer(inner_size, strict=True)

        # Initialize the linear transformations of the dense layers
        self.linear1 = torch.nn.Linear(inner_size, inner_size)
        self.linear2 = torch.nn.Linear(inner_size, inner_size)
        self.linear3 = torch.nn.Linear(inner_size, inner_size)

        # Initialize the scales of the dense layers
        self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
        self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
        self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))

        # Initialize the adaptive residual connection parameter
        self._alpha = torch.nn.Parameter(torch.zeros(1))

        # Initialize the activation function
        self.activation = activation()

    def forward(self, x, U, V):
        """
        Forward pass of the PirateNet block. It computes the output of the
        block by applying the dense layers with scaling, and combines the
        result with the input through the adaptive residual connection.

        :param x: The input tensor.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor U: The first shared gating tensor. It must have the
            same shape as ``x``.
        :param torch.Tensor V: The second shared gating tensor. It must have
            the same shape as ``x``.
        :return: The output tensor of the block.
        :rtype: torch.Tensor | LabelTensor
        """
        # Compute the output of the first dense layer with scaling
        f = self.activation(self.linear1(x) * torch.exp(self.scale1))
        z1 = f * U + (1 - f) * V

        # Compute the output of the second dense layer with scaling
        g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
        z2 = g * U + (1 - g) * V

        # Compute the output of the block
        h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
        return self._alpha * h + (1 - self._alpha) * x

    @property
    def alpha(self):
        """
        Return the alpha parameter.

        :return: The alpha parameter controlling the residual connection.
        :rtype: torch.nn.Parameter
        """
        return self._alpha
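A minimal usage sketch of the block above (sizes are illustrative only; the
gating tensors U and V are built by hand here, whereas in the full model they
come from the shared dense layers):

import torch
from pina.model.block import PirateNetBlock

# Illustrative sizes only
batch_size, inner_size = 16, 32
block = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)

# Dummy input and gating tensors, all with the same shape
x = torch.rand(batch_size, inner_size)
U = torch.rand(batch_size, inner_size)
V = torch.rand(batch_size, inner_size)

out = block(x, U, V)
assert out.shape == x.shape
# alpha is initialized to zero, so the block starts as an identity mapping
assert block.alpha.item() == 0.0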
pina/model/pirate_network.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""Module for the PirateNet model class."""

import torch
from .block import FourierFeatureEmbedding, PirateNetBlock
from ..utils import check_consistency, check_positive_integer


class PirateNet(torch.nn.Module):
    """
    Implementation of the Physics-Informed residual adaptive network
    (PirateNet).

    The model consists of a Fourier feature embedding layer, multiple PirateNet
    blocks, and a final output layer. Each PirateNet block consists of three
    dense layers with a dual gating mechanism and an adaptive residual
    connection, whose contribution is controlled by a trainable parameter
    ``alpha``.

    The PirateNet, augmented with random weight factorization, is designed to
    mitigate spectral bias in deep networks.

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed Neural
        Networks*.
        DOI: `arXiv preprint arXiv:2507.08972.
        <https://arxiv.org/abs/2507.08972>`_
    """

    def __init__(
        self,
        input_dimension,
        inner_size,
        output_dimension,
        embedding=None,
        n_layers=3,
        activation=torch.nn.Tanh,
    ):
        """
        Initialization of the :class:`PirateNet` class.

        :param int input_dimension: The number of input features.
        :param int inner_size: The number of hidden units in the dense layers.
        :param int output_dimension: The number of output features.
        :param torch.nn.Module embedding: The embedding module used to
            transform the input into a higher-dimensional feature space. If
            ``None``, a default
            :class:`~pina.model.block.FourierFeatureEmbedding` with a scaling
            factor of 2 is used. Default is ``None``.
        :param int n_layers: The number of PirateNet blocks in the model.
            Default is 3.
        :param torch.nn.Module activation: The activation function to be used
            in the blocks. Default is :class:`torch.nn.Tanh`.
        """
        super().__init__()

        # Check consistency
        check_consistency(activation, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimension, strict=True)
        check_positive_integer(inner_size, strict=True)
        check_positive_integer(output_dimension, strict=True)
        check_positive_integer(n_layers, strict=True)

        # Initialize the activation function
        self.activation = activation()

        # Initialize the Fourier embedding
        self.embedding = embedding or FourierFeatureEmbedding(
            input_dimension=input_dimension,
            output_dimension=inner_size,
            sigma=2.0,
        )

        # Initialize the shared dense layers
        self.linear1 = torch.nn.Linear(inner_size, inner_size)
        self.linear2 = torch.nn.Linear(inner_size, inner_size)

        # Initialize the PirateNet blocks
        self.blocks = torch.nn.ModuleList(
            [PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
        )

        # Initialize the output layer
        self.output_layer = torch.nn.Linear(inner_size, output_dimension)

    def forward(self, input_):
        """
        Forward pass of the PirateNet model. It applies the Fourier feature
        embedding, computes the shared gating tensors U and V, and passes the
        input through each block in the network. Finally, it applies the
        output layer to produce the final output.

        :param input_: The input tensor for the model.
        :type input_: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        # Apply the Fourier feature embedding
        x = self.embedding(input_)

        # Compute U and V from the shared dense layers
        U = self.activation(self.linear1(x))
        V = self.activation(self.linear2(x))

        # Pass through each block in the network
        for block in self.blocks:
            x = block(x, U, V)

        return self.output_layer(x)

    @property
    def alpha(self):
        """
        Return the alpha values of all PirateNetBlock layers.

        :return: A list of alpha values from each block.
        :rtype: list
        """
        return [block.alpha.item() for block in self.blocks]
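A usage sketch of the model above (sizes are illustrative only; with no
embedding passed, the default FourierFeatureEmbedding with sigma=2.0 is built
as in the constructor above):

import torch
from pina.model import PirateNet

model = PirateNet(
    input_dimension=3,
    inner_size=32,
    output_dimension=2,
    n_layers=3,
    activation=torch.nn.Tanh,
)

x = torch.rand(20, 3)
y = model(x)
assert y.shape == (20, 2)
# Each block's alpha starts at zero, so every block begins as an identity map
print(model.alpha)  # [0.0, 0.0, 0.0]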
tests/test_blocks/test_pirate_network_block.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
import pytest
from pina.model.block import PirateNetBlock

data = torch.rand((20, 3))


@pytest.mark.parametrize("inner_size", [10, 20])
def test_constructor(inner_size):

    PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)

    # Should fail if inner_size is negative
    with pytest.raises(AssertionError):
        PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)


@pytest.mark.parametrize("inner_size", [10, 20])
def test_forward(inner_size):

    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)

    # Create dummy embedding
    dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
    x = dummy_embedding(data)

    # Create dummy U and V tensors
    U = torch.rand((data.shape[0], inner_size))
    V = torch.rand((data.shape[0], inner_size))

    output_ = model(x, U, V)
    assert output_.shape == (data.shape[0], inner_size)


@pytest.mark.parametrize("inner_size", [10, 20])
def test_backward(inner_size):

    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
    data.requires_grad_()

    # Create dummy embedding
    dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
    x = dummy_embedding(data)

    # Create dummy U and V tensors
    U = torch.rand((data.shape[0], inner_size))
    V = torch.rand((data.shape[0], inner_size))

    output_ = model(x, U, V)

    loss = torch.mean(output_)
    loss.backward()
    assert data.grad.shape == data.shape
tests/test_model/test_pirate_network.py (new file, 120 lines)
@@ -0,0 +1,120 @@
import torch
import pytest
from pina.model import PirateNet
from pina.model.block import FourierFeatureEmbedding

data = torch.rand((20, 3))


@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_constructor(inner_size, n_layers, output_dimension):

    # Loop over the default and custom embedding
    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:

        # Constructor
        model = PirateNet(
            input_dimension=data.shape[1],
            inner_size=inner_size,
            output_dimension=output_dimension,
            embedding=embedding,
            n_layers=n_layers,
            activation=torch.nn.Tanh,
        )

        # Check the default embedding
        if embedding is None:
            assert isinstance(model.embedding, FourierFeatureEmbedding)
            assert model.embedding.sigma == 2.0

        # Should fail if input_dimension is negative
        with pytest.raises(AssertionError):
            PirateNet(
                input_dimension=-1,
                inner_size=inner_size,
                output_dimension=output_dimension,
                embedding=embedding,
                n_layers=n_layers,
                activation=torch.nn.Tanh,
            )

        # Should fail if inner_size is negative
        with pytest.raises(AssertionError):
            PirateNet(
                input_dimension=data.shape[1],
                inner_size=-1,
                output_dimension=output_dimension,
                embedding=embedding,
                n_layers=n_layers,
                activation=torch.nn.Tanh,
            )

        # Should fail if output_dimension is negative
        with pytest.raises(AssertionError):
            PirateNet(
                input_dimension=data.shape[1],
                inner_size=inner_size,
                output_dimension=-1,
                embedding=embedding,
                n_layers=n_layers,
                activation=torch.nn.Tanh,
            )

        # Should fail if n_layers is negative
        with pytest.raises(AssertionError):
            PirateNet(
                input_dimension=data.shape[1],
                inner_size=inner_size,
                output_dimension=output_dimension,
                embedding=embedding,
                n_layers=-1,
                activation=torch.nn.Tanh,
            )


@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_forward(inner_size, n_layers, output_dimension):

    # Loop over the default and custom embedding
    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:

        model = PirateNet(
            input_dimension=data.shape[1],
            inner_size=inner_size,
            output_dimension=output_dimension,
            embedding=embedding,
            n_layers=n_layers,
            activation=torch.nn.Tanh,
        )

        output_ = model(data)
        assert output_.shape == (data.shape[0], output_dimension)


@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_backward(inner_size, n_layers, output_dimension):

    # Loop over the default and custom embedding
    for embedding in [None, torch.nn.Linear(data.shape[1], inner_size)]:

        model = PirateNet(
            input_dimension=data.shape[1],
            inner_size=inner_size,
            output_dimension=output_dimension,
            embedding=embedding,
            n_layers=n_layers,
            activation=torch.nn.Tanh,
        )

        data.requires_grad_()
        output_ = model(data)

        loss = torch.mean(output_)
        loss.backward()
        assert data.grad.shape == data.shape
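The new tests can be run selectively with pytest, e.g.:

pytest tests/test_blocks/test_pirate_network_block.py tests/test_model/test_pirate_network.py -q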