"""Gated graph neural operator (GNO) with conditional message passing.

The file was collapsed onto two physical lines in the original; it has been
reformatted. All public interfaces and parameter/submodule names are
unchanged, so existing callers and checkpoints keep working.
"""

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
from matplotlib.tri import Triangulation


def plot_results_fn(x, pos, i, batch):
    """Save a filled tricontour plot of channel 0 of the first graph.

    Args:
        x: node features, shape ``(N, C)``; only channel 0 is drawn.
        pos: node positions, shape ``(N, 2)`` — TODO confirm 2-D layout.
        i: integer used in the output filename ``out_{i:03d}.png``.
        batch: per-node graph index; only nodes with ``batch == 0`` are drawn.
    """
    x = x[batch == 0]
    pos = pos[batch == 0]
    tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu())
    import matplotlib.pyplot as plt  # local import: plotting is optional

    plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
    plt.colorbar()
    # Fix: the aspect ratio must be set *before* saving — the original
    # called plt.axis("equal") after plt.savefig, so it never affected
    # the written file.
    plt.axis("equal")
    plt.savefig(f"out_{i:03d}.png")
    plt.close()


class WarmUpNet(nn.Module):
    """Small MLP that pre-processes node features, then re-imposes
    Dirichlet-style boundary values on the masked nodes."""

    def __init__(self, input_dim, hidden_dim=64, n_layers=3):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, input_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, boundary_mask, boundary_values):
        """Apply the MLP, then overwrite boundary nodes with the given values."""
        x = self.net(x)
        x[boundary_mask] = boundary_values
        return x


class EncX(nn.Module):
    """Two-layer MLP encoder for node state features."""

    def __init__(self, x_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_ch, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, hidden),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class EncC(nn.Module):
    """Two-layer MLP encoder for node conditioning features."""

    def __init__(self, c_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(c_ch, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, hidden),
            nn.GELU(),
        )

    def forward(self, c):
        return self.net(c)


class DecX(nn.Module):
    """Two-layer MLP decoder from hidden width to output channels."""

    def __init__(self, hidden, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, out_ch),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class ConditionalGNOBlock(MessagePassing):
    """One message-passing step whose messages are gated by edge attributes
    and by an averaged per-edge conditioning signal, with a learned residual
    blend in the node update."""

    def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
        super().__init__(aggr=aggr, node_dim=0)
        # Scalar positive gate computed from raw edge attributes.
        self.edge_attr_net = nn.Sequential(
            nn.Linear(edge_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Softplus(),
        )
        # Transforms the feature difference x_j - x_i into a message.
        self.diff_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch * 2),
            nn.GELU(),
            nn.Linear(hidden_ch * 2, hidden_ch),
            nn.GELU(),
        )
        # Scalar (0, 1) gate from the averaged edge conditioning c_ij.
        self.c_ij_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Sigmoid(),
        )
        # NOTE(review): gamma_net is currently unused by message() (the
        # original computed gamma and discarded it). The module is kept so
        # existing checkpoints still load; wire it in or remove it.
        self.gamma_net = nn.Sequential(
            nn.Linear(2 * hidden_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Sigmoid(),
        )
        # Scalar (0, 1) blend factor for the residual node update.
        self.alpha_net = nn.Sequential(
            nn.Linear(2 * hidden_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, c, edge_index, edge_attr=None):
        return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)

    def message(self, x_i, x_j, c_i, c_j, edge_attr):
        """Build the per-edge message: gated feature difference."""
        c_ij = 0.5 * (c_i + c_j)
        # Fix: guard against edge_attr=None (the documented default of
        # forward); the original unconditionally called edge_attr_net and
        # crashed. With no edge attributes, no edge gate is applied.
        gate = self.edge_attr_net(edge_attr) if edge_attr is not None else 1.0
        # Fix: removed the dead `gamma = self.gamma_net(...)` computation —
        # its result was never used, so it only wasted compute/graph memory.
        m = self.diff_net(x_j - x_i) * gate
        m = m * self.c_ij_net(c_ij)
        return m

    def update(self, aggr_out, x):
        """Residual update: blend aggregated messages into x via alpha."""
        alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
        return x + alpha * aggr_out


class GatingGNO(nn.Module):
    """Graph neural operator: warm-up MLP, node/condition encoders, a stack
    of ConditionalGNOBlocks applied for ``unrolling_steps`` rounds, and a
    decoder back to ``out_ch`` channels. Optionally dumps intermediate
    states as tricontour plots via :func:`plot_results_fn`."""

    def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1):
        super().__init__()
        self.warmup = WarmUpNet(x_ch_node, hidden_dim=hidden, n_layers=3)
        self.encoder_x = EncX(x_ch_node, hidden)
        self.encoder_c = EncC(f_ch_node, hidden)
        self.blocks = nn.ModuleList(
            [
                ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
                for _ in range(layers)
            ]
        )
        self.dec = DecX(hidden, out_ch)

    def forward(
        self,
        x,
        c,
        edge_index,
        edge_attr=None,
        unrolling_steps=1,
        plot_results=False,
        batch=None,
        pos=None,
        boundary_mask=None,
    ):
        """Run warm-up, encode, unroll the GNO blocks, and decode.

        Args:
            x: node features, shape ``(N, x_ch_node)``.
            c: conditioning features, shape ``(N, f_ch_node)``.
            edge_index: COO edge index, shape ``(2, E)``.
            edge_attr: optional edge attributes, shape ``(E, edge_ch)``.
            unrolling_steps: number of passes over the block stack.
            plot_results: if True, save a plot of the decoded state after
                every block application (requires ``batch`` and ``pos``).
            batch / pos: only needed when ``plot_results`` is True.
            boundary_mask: boolean node mask whose values are re-imposed
                after the warm-up MLP.
        """
        # Warm-up keeps the original values on the boundary nodes.
        x = self.warmup(x, boundary_mask, x[boundary_mask])
        x = self.encoder_x(x)
        c = self.encoder_c(c)

        if plot_results:
            plot_results_fn(self.dec(x), pos, 0, batch=batch)

        # Fix: removed `assert bc == x[boundary_mask]` — `bc` was undefined,
        # so any plotting run raised NameError. Fix: frame indices were
        # `i * step`, which collide (block 0 always maps to frame 0 and
        # e.g. 2*3 == 3*2) and overwrote files; a running counter is used.
        frame = 0
        for _step in range(1, unrolling_steps + 1):
            for blk in self.blocks:
                x = blk(x, c, edge_index, edge_attr=edge_attr)
                if plot_results:
                    frame += 1
                    plot_results_fn(self.dec(x), pos, frame, batch=batch)

        return self.dec(x)