from pina.model import GraphNeuralOperator
import torch
from torch_geometric.data import Data


class GNO(torch.nn.Module):
    """Thin wrapper around PINA's ``GraphNeuralOperator``.

    Two per-node inputs are concatenated along the last dimension, lifted to
    ``hidden`` channels, processed by the graph neural operator, and projected
    to ``out_ch`` channels.

    Args:
        x_ch_node: Number of channels in the first node input ``x``.
        f_ch_node: Number of channels in the second node input ``c``.
        hidden: Width produced by the lifting operator and consumed by the
            projection operator.
        layers: Passed to ``GraphNeuralOperator`` as ``n_layers``.
        edge_ch: Passed to ``GraphNeuralOperator`` as ``edge_features``.
        out_ch: Output width of the projection operator.
        internal_n_layers: Passed through as ``internal_n_layers``
            (keyword-only; default 2, the previously hard-coded value).
        shared_weights: Passed through as ``shared_weights``
            (keyword-only; default False, the previously hard-coded value).
    """

    def __init__(
        self,
        x_ch_node,
        f_ch_node,
        hidden,
        layers,
        edge_ch=0,
        out_ch=1,
        *,
        internal_n_layers=2,
        shared_weights=False,
    ):
        super().__init__()
        # forward() concatenates x and c before lifting, so the lifting
        # layer's input width must be the sum of both channel counts.
        lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
        projection_operator = torch.nn.Linear(hidden, out_ch)
        self.gno = GraphNeuralOperator(
            lifting_operator=lifting_operator,
            projection_operator=projection_operator,
            edge_features=edge_ch,
            n_layers=layers,
            internal_n_layers=internal_n_layers,
            shared_weights=shared_weights,
        )

    def forward(self, x, c, edge_index, edge_attr):
        """Run the operator on one graph.

        Args:
            x: First node input; presumably (num_nodes, x_ch_node) — TODO confirm.
            c: Second node input; concatenated with ``x`` on the last dim, so
                all leading dims must match ``x``.
            edge_index: Graph connectivity, stored as ``Data.edge_index``.
            edge_attr: Edge features, stored as ``Data.edge_attr``.
                NOTE(review): presumably should have ``edge_ch`` channels — confirm.

        Returns:
            Whatever ``GraphNeuralOperator`` returns for the assembled graph.
        """
        node_features = torch.cat([x, c], dim=-1)
        graph = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
        return self.gno(graph)