implement first GNO
This commit is contained in:
25
ThermalSolver/model/basic_gno.py
Normal file
25
ThermalSolver/model/basic_gno.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pina.model import GraphNeuralOperator
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
|
||||
|
||||
class GNO(torch.nn.Module):
|
||||
def __init__(
|
||||
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
|
||||
self.gno = GraphNeuralOperator(
|
||||
lifting_operator=lifting_operator,
|
||||
projection_operator=torch.nn.Linear(hidden, out_ch),
|
||||
edge_features=edge_ch,
|
||||
n_layers=layers,
|
||||
internal_n_layers=2,
|
||||
shared_weights=False,
|
||||
)
|
||||
|
||||
def forward(self, x, c, edge_index, edge_attr):
|
||||
x = torch.cat([x, c], dim=-1)
|
||||
x = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
||||
return self.gno(x)
|
||||
Reference in New Issue
Block a user