26 lines
800 B
Python
26 lines
800 B
Python
from pina.model import GraphNeuralOperator
|
|
import torch
|
|
from torch_geometric.data import Data
|
|
|
|
|
|
class GNO(torch.nn.Module):
    """Thin wrapper around PINA's ``GraphNeuralOperator`` for single graphs.

    The two per-node inputs ``x`` and ``c`` are concatenated along their last
    dimension. The result is lifted by ``Linear(x_ch_node + f_ch_node, hidden)``,
    processed by the PINA operator with ``n_layers=layers``, and projected by
    ``Linear(hidden, out_ch)``. The internals of ``GraphNeuralOperator`` are
    not shown here; see the PINA documentation for the exact kernel update.

    :param int x_ch_node: Channels of the ``x`` node tensor (last dimension).
    :param int f_ch_node: Channels of the ``c`` node tensor (last dimension).
    :param int hidden: Width of the lifted latent representation.
    :param int layers: Forwarded to ``GraphNeuralOperator`` as ``n_layers``.
    :param int edge_ch: Forwarded as ``edge_features``. Presumably the number
        of channels in ``edge_attr``; TODO confirm.
    :param int out_ch: Output width of the projection layer.
    :param int internal_n_layers: Forwarded to ``GraphNeuralOperator``.
        Defaults to 2, the value previously hard-coded here.
    :param bool shared_weights: Forwarded to ``GraphNeuralOperator``.
        Defaults to ``False``, the value previously hard-coded here.
    """

    def __init__(
        self,
        x_ch_node,
        f_ch_node,
        hidden,
        layers,
        edge_ch=0,
        out_ch=1,
        *,
        internal_n_layers=2,
        shared_weights=False,
    ):
        super().__init__()

        # Keep construction order: lifting first, then projection, so that
        # parameter initialisation under a fixed RNG seed is unchanged.
        lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
        projection_operator = torch.nn.Linear(hidden, out_ch)

        self.gno = GraphNeuralOperator(
            lifting_operator=lifting_operator,
            projection_operator=projection_operator,
            edge_features=edge_ch,
            n_layers=layers,
            internal_n_layers=internal_n_layers,
            shared_weights=shared_weights,
        )

    def forward(self, x, c, edge_index, edge_attr):
        """Apply the operator to one graph.

        :param torch.Tensor x: Per-node tensor. Its last dimension is expected
            to have ``x_ch_node`` channels.
        :param torch.Tensor c: Per-node tensor, concatenated after ``x`` along
            the last dimension. Expected to have ``f_ch_node`` channels.
        :param edge_index: Graph connectivity, stored unchanged on the
            ``Data`` object.
        :param edge_attr: Edge attributes, stored unchanged on the ``Data``
            object.
        :return: Whatever the wrapped ``GraphNeuralOperator`` returns for the
            graph. Presumably per-node outputs of width ``out_ch``; confirm
            against PINA.
        """
        # The lifting layer expects x_ch_node + f_ch_node input channels.
        node_features = torch.cat([x, c], dim=-1)
        graph = Data(
            x=node_features, edge_index=edge_index, edge_attr=edge_attr
        )
        return self.gno(graph)
|