Rename files
This commit is contained in:
committed by
Nicola Demo
parent
886bd23fdb
commit
ff43a7492b
177
pina/model/graph_neural_operator.py
Normal file
177
pina/model/graph_neural_operator.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import torch
|
||||
from torch.nn import Tanh
|
||||
from .block import GNOBlock
|
||||
from .kernel_neural_operator import KernelNeuralOperator
|
||||
|
||||
|
||||
class GraphNeuralKernel(torch.nn.Module):
|
||||
"""
|
||||
TODO add docstring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
edge_features,
|
||||
n_layers=2,
|
||||
internal_n_layers=0,
|
||||
internal_layers=None,
|
||||
inner_size=None,
|
||||
internal_func=None,
|
||||
external_func=None,
|
||||
shared_weights=False
|
||||
):
|
||||
"""
|
||||
The Graph Neural Kernel constructor.
|
||||
|
||||
:param width: The width of the kernel.
|
||||
:type width: int
|
||||
:param edge_features: The number of edge features.
|
||||
:type edge_features: int
|
||||
:param n_layers: The number of kernel layers.
|
||||
:type n_layers: int
|
||||
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer.
|
||||
:type internal_n_layers: int
|
||||
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
|
||||
:type internal_layers: list | tuple
|
||||
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
|
||||
:param external_func: The activation function applied to the output of the Graph Integral Layer.
|
||||
:type external_func: torch.nn.Module
|
||||
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
|
||||
"""
|
||||
super().__init__()
|
||||
if external_func is None:
|
||||
external_func = Tanh
|
||||
if internal_func is None:
|
||||
internal_func = Tanh
|
||||
|
||||
if shared_weights:
|
||||
self.layers = GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
layers=internal_layers,
|
||||
inner_size=inner_size,
|
||||
internal_func=internal_func,
|
||||
external_func=external_func)
|
||||
self.n_layers = n_layers
|
||||
self.forward = self.forward_shared
|
||||
else:
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
layers=internal_layers,
|
||||
inner_size=inner_size,
|
||||
internal_func=internal_func,
|
||||
external_func=external_func
|
||||
)
|
||||
for _ in range(n_layers)]
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index, edge_attr):
|
||||
"""
|
||||
The forward pass of the Graph Neural Kernel used when the weights are not shared.
|
||||
|
||||
:param x: The input batch.
|
||||
:type x: torch.Tensor
|
||||
:param edge_index: The edge index.
|
||||
:type edge_index: torch.Tensor
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor
|
||||
"""
|
||||
for layer in self.layers:
|
||||
x = layer(x, edge_index, edge_attr)
|
||||
return x
|
||||
|
||||
def forward_shared(self, x, edge_index, edge_attr):
|
||||
"""
|
||||
The forward pass of the Graph Neural Kernel used when the weights are shared.
|
||||
|
||||
:param x: The input batch.
|
||||
:type x: torch.Tensor
|
||||
:param edge_index: The edge index.
|
||||
:type edge_index: torch.Tensor
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor
|
||||
"""
|
||||
for _ in range(self.n_layers):
|
||||
x = self.layers(x, edge_index, edge_attr)
|
||||
return x
|
||||
|
||||
|
||||
class GraphNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
TODO add docstring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifting_operator,
|
||||
projection_operator,
|
||||
edge_features,
|
||||
n_layers=10,
|
||||
internal_n_layers=0,
|
||||
inner_size=None,
|
||||
internal_layers=None,
|
||||
internal_func=None,
|
||||
external_func=None,
|
||||
shared_weights=True
|
||||
):
|
||||
"""
|
||||
The Graph Neural Operator constructor.
|
||||
|
||||
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension.
|
||||
:type lifting_operator: torch.nn.Module
|
||||
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function.
|
||||
:type projection_operator: torch.nn.Module
|
||||
:param edge_features: Number of edge features.
|
||||
:type edge_features: int
|
||||
:param n_layers: The number of kernel layers.
|
||||
:type n_layers: int
|
||||
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer.
|
||||
:type internal_n_layers: int
|
||||
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
|
||||
:type internal_layers: list | tuple
|
||||
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
|
||||
:type internal_func: torch.nn.Module
|
||||
:param external_func: The activation function applied to the output of the Graph Integral Kernel.
|
||||
:type external_func: torch.nn.Module
|
||||
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
|
||||
:type shared_weights: bool
|
||||
"""
|
||||
|
||||
if internal_func is None:
|
||||
internal_func = Tanh
|
||||
if external_func is None:
|
||||
external_func = Tanh
|
||||
|
||||
super().__init__(
|
||||
lifting_operator=lifting_operator,
|
||||
integral_kernels=GraphNeuralKernel(
|
||||
width=lifting_operator.out_features,
|
||||
edge_features=edge_features,
|
||||
internal_n_layers=internal_n_layers,
|
||||
inner_size=inner_size,
|
||||
internal_layers=internal_layers,
|
||||
external_func=external_func,
|
||||
internal_func=internal_func,
|
||||
n_layers=n_layers,
|
||||
shared_weights=shared_weights
|
||||
),
|
||||
projection_operator=projection_operator
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
The forward pass of the Graph Neural Operator.
|
||||
|
||||
:param x: The input batch.
|
||||
:type x: torch_geometric.data.Batch
|
||||
"""
|
||||
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
|
||||
x = self.lifting_operator(x)
|
||||
x = self.integral_kernels(x, edge_index, edge_attr)
|
||||
x = self.projection_operator(x)
|
||||
return x
|
||||
Reference in New Issue
Block a user