import torch
from torch import nn
from torch_geometric.nn import MessagePassing
from matplotlib.tri import Triangulation


def plot_results_fn(x, pos, i, batch):
    """Save a filled contour plot of the first channel of graph 0 to out_{i:03d}.png."""
    import matplotlib.pyplot as plt

    x = x[batch == 0].detach()
    pos = pos[batch == 0].detach()
    tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu())
    plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
    plt.colorbar()
    plt.axis("equal")  # must come before savefig, or it has no effect on the saved file
    plt.savefig(f"out_{i:03d}.png")
    plt.close()


class EncX(nn.Module):
    """Encode the node state x into the hidden dimension."""

    def __init__(self, x_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_ch, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, hidden),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class EncC(nn.Module):
    """Encode the per-node condition c into the hidden dimension."""

    def __init__(self, c_ch, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(c_ch, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, hidden),
            nn.GELU(),
        )

    def forward(self, c):
        return self.net(c)


class DecX(nn.Module):
    """Decode hidden node features back to out_ch output channels."""

    def __init__(self, hidden, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, out_ch),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class ConditionalGNOBlock(MessagePassing):
    def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
        super().__init__(aggr=aggr, node_dim=0)
        self.edge_ch = edge_ch
        # Maps edge_attr to a non-negative scalar edge weight (via Softplus).
        # If edge_ch == 0, a global learnable coefficient would be used
        # instead (not implemented yet).
        self.edge_attr_net = nn.Sequential(
            nn.Linear(edge_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Softplus(),
        )
        # Gate derived from the edge condition c_ij (scalar in (0, 1)).
        self.c_ij_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Sigmoid(),
        )
        # Per-step residual weight alpha, squashed into (0, 1) by a sigmoid.
        self.alpha_net = nn.Sequential(
            nn.Linear(2 * hidden_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch // 2),
            nn.GELU(),
            nn.Linear(hidden_ch // 2, 1),
            nn.Sigmoid(),
        )
        # Nonlinear transform of the feature difference. Note the middle
        # layer is hidden_ch**2 wide, so its parameter count grows
        # quadratically with hidden_ch.
        self.diff_net = nn.Sequential(
            nn.Linear(hidden_ch, hidden_ch * 2),
            nn.GELU(),
            nn.Linear(hidden_ch * 2, hidden_ch**2),
            nn.GELU(),
            nn.Linear(hidden_ch**2, hidden_ch),
            nn.GELU(),
        )
        # self.norm = nn.LayerNorm(hidden_ch)

    def forward(self, x, c, edge_index, edge_attr=None):
        # Delegate to propagate(). edge_attr defaults to None, but message()
        # currently requires it, since edge_attr_net is applied unconditionally.
        return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)

    def message(self, x_i, x_j, c_i, c_j, edge_attr):
        """
        Diffusive message with a learned edge weight and a condition gate:

            m_ij = (w_ij * diff_net(x_j - x_i) + (x_j - x_i)) * g_ij

        where w_ij = softplus(edge_attr_net(edge_attr_ij)) >= 0 and
        g_ij = c_ij_net(0.5 * (c_i + c_j)) in (0, 1).
        """
        c_ij = 0.5 * (c_i + c_j)               # [E, H] edge condition
        c_gate = self.c_ij_net(c_ij)           # [E, 1] in (0, 1)
        w_raw = self.edge_attr_net(edge_attr)  # [E, 1] non-negative
        w = w_raw + 1e-8                       # keep the weight strictly positive
        diff = x_j - x_i                       # [E, H]
        m = w * self.diff_net(diff) + diff     # [E, H] weighted transform + skip term
        m = m * c_gate                         # [E, H] condition gating
        return m

    def update(self, aggr_out, x):
        """
        Gated residual update: x_new = x + alpha * aggr_out, where
        alpha = alpha_net([x, aggr_out]) in (0, 1) acts as a learned
        per-node step size.
        """
        alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
        x_new = x + alpha * aggr_out
        return x_new


class GatingGNO(nn.Module):
    """
    Gated conditional graph neural operator: encodes the node state x and the
    condition c, applies `layers` ConditionalGNOBlocks for `unrolling_steps`
    rollout steps, and decodes the hidden state to `out_ch` channels.
    """

    def __init__(
        self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
    ):
        super().__init__()
        self.encoder_x = EncX(x_ch_node, hidden)
        self.encoder_c = EncC(f_ch_node, hidden)
        self.blocks = nn.ModuleList(
            [
                ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
                for _ in range(layers)
            ]
        )
        self.dec = DecX(hidden, out_ch)

    def forward(
        self,
        x,
        c,
        boundary,
        boundary_mask,
        edge_index,
        edge_attr=None,
        unrolling_steps=1,
        plot_results=False,
        batch=None,
        pos=None,
    ):
        # boundary and boundary_mask are accepted but not used yet.
        x = self.encoder_x(x)
        c = self.encoder_c(c)
        frame = 0
        if plot_results:
            plot_results_fn(self.dec(x), pos, frame, batch=batch)
        for _ in range(unrolling_steps):
            for blk in self.blocks:
                x = blk(x, c, edge_index, edge_attr=edge_attr)
                if plot_results:
                    frame += 1  # monotonic counter: one frame per block application
                    plot_results_fn(self.dec(x), pos, frame, batch=batch)
        return self.dec(x)
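

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch showing how GatingGNO is wired together, run on
# a small random graph. All sizes here (3 state channels, 2 condition
# channels, 4 edge features, 32 hidden units) are illustrative assumptions,
# not values taken from any experiment or dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_nodes, num_edges = 10, 40
    x = torch.randn(num_nodes, 3)            # node state, x_ch_node = 3
    c = torch.randn(num_nodes, 2)            # per-node condition, f_ch_node = 2
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.randn(num_edges, 4)    # edge features, edge_ch = 4

    model = GatingGNO(
        x_ch_node=3, f_ch_node=2, hidden=32, layers=2, edge_ch=4, out_ch=1
    )
    out = model(
        x,
        c,
        boundary=None,        # unused placeholder (see forward())
        boundary_mask=None,   # unused placeholder (see forward())
        edge_index=edge_index,
        edge_attr=edge_attr,
        unrolling_steps=2,
    )
    print(out.shape)  # expected: torch.Size([10, 1])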