import torch
from torch import nn
from torch_geometric.nn import MessagePassing
from matplotlib.tri import Triangulation


def _import_boundary_conditions(x, boundary, boundary_mask):
    """Overwrite the entries of ``x`` selected by ``boundary_mask``.

    The assignment ``x[boundary_mask] = boundary`` is performed in place on
    ``x``; nothing is returned.
    """
    x[boundary_mask] = boundary


def plot_results_fn(x, pos, i, batch=None):
    """Save a filled contour plot of ``x[:, 0]`` to ``out_{i:03d}.png``.

    Args:
        x: Node values; only column 0 is plotted.
        pos: Node coordinates; columns 0 and 1 are used for the triangulation.
        i: Integer used to build the output file name.
        batch: Optional per-node graph index. When given, only nodes with
            ``batch == 0`` are plotted; when ``None``, all nodes are plotted.

    Raises:
        ValueError: If ``pos`` is ``None``.
    """
    if pos is None:
        raise ValueError("plot_results_fn requires node positions `pos`.")

    if batch is not None:
        first_graph = batch == 0
        x = x[first_graph]
        pos = pos[first_graph]

    # Detach before handing data to matplotlib: converting a tensor that
    # requires grad to a NumPy array raises a RuntimeError.
    values = x[:, 0].detach().cpu().numpy()
    coords = pos.detach().cpu().numpy()
    tria = Triangulation(coords[:, 0], coords[:, 1])

    import matplotlib.pyplot as plt

    plt.tricontourf(tria, values, levels=14)
    plt.colorbar()
    # The aspect ratio must be set before saving, otherwise it has no effect
    # on the written image.
    plt.axis("equal")
    plt.savefig(f"out_{i:03d}.png")
    plt.close()


class EncX(nn.Module):
    """Two-layer MLP: ``x_ch -> hidden // 2 -> hidden`` with a SiLU between."""

    def __init__(self, x_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_ch, hidden // 2),
            nn.SiLU(),
            nn.Linear(hidden // 2, hidden),
        )

    def forward(self, x):
        """Return ``self.net(x)``."""
        return self.net(x)


class EncC(nn.Module):
    """Two-layer MLP: ``c_ch -> hidden // 2 -> hidden`` with a SiLU between."""

    def __init__(self, c_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(c_ch, hidden // 2),
            nn.SiLU(),
            nn.Linear(hidden // 2, hidden),
        )

    def forward(self, c):
        """Return ``self.net(c)``."""
        return self.net(c)


class DecX(nn.Module):
    """Two-layer MLP: ``hidden -> hidden // 2 -> out_ch`` with a SiLU between."""

    def __init__(self, hidden, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.SiLU(),
            nn.Linear(hidden // 2, out_ch),
        )

    def forward(self, x):
        """Return ``self.net(x)``."""
        return self.net(x)


class ConditionalGNOBlock(MessagePassing):
    """Residual message-passing layer modulated by a conditioning field.

    Each edge message blends two terms with a learned weight
    ``sigmoid(self.balancing)``: a network applied to ``x_j - x_i`` and a
    network applied to ``x_j`` (optionally gated by the edge attributes).
    The blend is multiplied by a network of the averaged conditioning
    ``0.5 * (c_i + c_j)``. Aggregated messages are projected and added to
    ``x`` scaled by the learned scalar ``self.alpha``.

    Args:
        hidden_ch: Width of the node and conditioning features.
        edge_ch: Width of the edge attributes fed to the gate network.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
        super().__init__(aggr=aggr, node_dim=0)

        # NOTE: submodules are created in a fixed order so parameter
        # initialisation and state_dict layout stay stable.
        self.edge_attr_net = nn.Sequential(
            nn.Linear(edge_ch, hidden_ch // 2),
            nn.SiLU(),
            nn.Linear(hidden_ch // 2, hidden_ch),
            nn.Tanh(),
        )
        self.msg_proj = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch),
            nn.SiLU(),
            nn.Linear(hidden_ch, hidden_ch),
        )
        self.diff_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch),
            nn.SiLU(),
            nn.Linear(hidden_ch, hidden_ch),
        )
        self.x_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch),
            nn.SiLU(),
            nn.Linear(hidden_ch, hidden_ch),
        )
        self.c_ij_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch),
            nn.SiLU(),
            nn.Linear(hidden_ch, hidden_ch),
            nn.Tanh(),
        )
        # Raw (pre-sigmoid) blend weight between the two message terms.
        self.balancing = nn.Parameter(torch.tensor(0.0))
        # Scale of the residual update applied in ``update``.
        self.alpha = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, c, edge_index, edge_attr=None):
        """Run one round of message passing and return the updated ``x``."""
        return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)

    def message(self, x_i, x_j, c_i, c_j, edge_attr):
        """Build the per-edge message.

        When ``edge_attr`` is ``None`` the edge gate is skipped, so the
        ``x_net(x_j)`` term is used ungated instead of failing inside
        ``edge_attr_net``.
        """
        c_ij = 0.5 * (c_i + c_j)
        alpha = torch.sigmoid(self.balancing)

        neighbor_term = (1 - alpha) * self.x_net(x_j)
        if edge_attr is not None:
            gate = torch.sigmoid(self.edge_attr_net(edge_attr))
            neighbor_term = neighbor_term * gate

        m = alpha * self.diff_net(x_j - x_i) + neighbor_term
        return m * self.c_ij_net(c_ij)

    def update(self, aggr_out, x):
        """Add the projected aggregate to ``x``, scaled by ``self.alpha``."""
        return x + self.alpha * self.msg_proj(aggr_out)


class GatingGNO(nn.Module):
    """Encoder / message-passing / decoder network with unrolled steps.

    ``forward`` encodes the node state ``x`` and the conditioning ``c``,
    then for each unrolling step writes the encoded boundary values into
    the state and applies every ``ConditionalGNOBlock`` in sequence. The
    final state is decoded with ``DecX``.

    Args:
        x_ch_node: Input width of the node state (and of ``boundary``,
            which goes through the same encoder).
        f_ch_node: Input width of the conditioning field ``c``.
        hidden: Hidden feature width.
        layers: Number of ``ConditionalGNOBlock`` layers.
        edge_ch: Edge attribute width passed to every block.
        out_ch: Output width of the decoder.
    """

    def __init__(
        self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
    ):
        super().__init__()
        self.encoder_x = EncX(x_ch_node, hidden)
        self.encoder_c = EncC(f_ch_node, hidden)
        self.blocks = nn.ModuleList(
            [
                ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
                for _ in range(layers)
            ]
        )
        self.dec = DecX(hidden, out_ch)

    def _plot_state(self, x, pos, step, batch):
        """Decode ``x`` without tracking gradients and save a plot of it."""
        with torch.no_grad():
            decoded = self.dec(x)
        plot_results_fn(decoded, pos, step, batch=batch)

    def forward(
        self,
        x,
        c,
        boundary,
        boundary_mask,
        edge_index,
        edge_attr=None,
        unrolling_steps=1,
        plot_results=False,
        batch=None,
        pos=None,
    ):
        """Encode, unroll the message-passing blocks, and decode.

        Args:
            x: Node state passed to ``encoder_x``.
            c: Conditioning field passed to ``encoder_c``.
            boundary: Boundary values, encoded with ``encoder_x`` and written
                into the state where ``boundary_mask`` selects.
            boundary_mask: Index/mask used as ``x[boundary_mask] = boundary``.
            edge_index: Graph connectivity forwarded to each block.
            edge_attr: Optional edge attributes forwarded to each block.
            unrolling_steps: Number of times the full block stack is applied.
            plot_results: If true, save a plot of the decoded state before
                the first step (index 0) and after every step.
            batch: Optional per-node graph index used only for plotting.
            pos: Node coordinates; required when ``plot_results`` is true.

        Returns:
            The decoded state ``self.dec(x)`` after the last step.

        Raises:
            ValueError: If ``plot_results`` is true and ``pos`` is ``None``.
        """
        if plot_results and pos is None:
            raise ValueError("`pos` must be provided when plot_results=True.")

        x = self.encoder_x(x)
        c = self.encoder_c(c)
        boundary = self.encoder_x(boundary)

        if plot_results:
            _import_boundary_conditions(x, boundary, boundary_mask)
            self._plot_state(x, pos, 0, batch)

        for step in range(1, unrolling_steps + 1):
            # Re-impose the boundary values before every unrolling step.
            _import_boundary_conditions(x, boundary, boundary_mask)
            for blk in self.blocks:
                x = blk(x, c, edge_index, edge_attr=edge_attr)
            if plot_results:
                self._plot_state(x, pos, step, batch)

        return self.dec(x)