Files
thermal-conduction-ml/ThermalSolver/model/basic_gno.py
2025-09-25 14:44:39 +02:00

26 lines
800 B
Python

from pina.model import GraphNeuralOperator
import torch
from torch_geometric.data import Data
class GNO(torch.nn.Module):
    """Thin wrapper around PINA's ``GraphNeuralOperator``.

    The wrapper wires a linear lifting layer and a linear projection layer
    into the operator, and exposes a ``forward`` that accepts the graph as
    separate tensors instead of a ``torch_geometric`` ``Data`` object.
    """

    def __init__(
        self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
    ):
        """Build the underlying graph neural operator.

        Args:
            x_ch_node: First part of the per-node input width; summed with
                ``f_ch_node`` to size the lifting layer.
                NOTE(review): presumably the channel count of ``x`` in
                ``forward`` -- confirm against the caller.
            f_ch_node: Second part of the per-node input width.
                NOTE(review): presumably the channel count of ``c`` in
                ``forward`` -- confirm against the caller.
            hidden: Output width of the lifting layer and input width of
                the projection layer.
            layers: Forwarded to the operator as ``n_layers``.
            edge_ch: Forwarded to the operator as ``edge_features``.
            out_ch: Output width of the projection layer.
        """
        super().__init__()

        # The lifting layer is created before the projection layer, so the
        # order in which parameters are randomly initialised stays
        # lifting -> projection -> operator internals.
        node_in_ch = x_ch_node + f_ch_node
        lift = torch.nn.Linear(node_in_ch, hidden)
        project = torch.nn.Linear(hidden, out_ch)

        self.gno = GraphNeuralOperator(
            lifting_operator=lift,
            projection_operator=project,
            edge_features=edge_ch,
            n_layers=layers,
            internal_n_layers=2,
            shared_weights=False,
        )

    def forward(self, x, c, edge_index, edge_attr):
        """Run the operator on a graph assembled from the given tensors.

        ``x`` and ``c`` are concatenated along their last dimension to form
        the node features; the combined width has to match the input width
        of the lifting layer (``x_ch_node + f_ch_node``).

        Args:
            x: First node-feature tensor.
            c: Second node-feature tensor, joined to ``x`` on the last axis.
            edge_index: Stored on the ``Data`` object as ``edge_index``.
            edge_attr: Stored on the ``Data`` object as ``edge_attr``.

        Returns:
            Whatever the wrapped ``GraphNeuralOperator`` returns for the
            assembled ``Data`` object.
        """
        node_features = torch.cat((x, c), dim=-1)
        graph = Data(
            x=node_features, edge_index=edge_index, edge_attr=edge_attr
        )
        return self.gno(graph)