diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py
index efdb868..d4e1099 100644
--- a/ThermalSolver/model/local_gno.py
+++ b/ThermalSolver/model/local_gno.py
@@ -59,92 +59,40 @@ class DecX(nn.Module):
         return self.net(x)
 
-# class ConditionalGNOBlock(MessagePassing):
-#     def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
-#         super().__init__(aggr=aggr, node_dim=0)
-
-#         self.edge_attr_net = nn.Sequential(
-#             nn.Linear(edge_ch, hidden_ch // 2),
-#             nn.SiLU(),
-#             nn.Linear(hidden_ch // 2, 1),
-#             nn.Softplus()
-#         )
-
-#         self.diff_net = nn.Sequential(
-#             nn.Linear(hidden_ch, hidden_ch),
-#             nn.SiLU(),
-#             nn.Linear(hidden_ch, hidden_ch),
-#         )
-
-#         # self.x_net = nn.Sequential(
-#         #     nn.Linear(hidden_ch, hidden_ch),
-#         #     nn.SiLU(),
-#         #     nn.Linear(hidden_ch, hidden_ch),
-#         # )
-
-#         self.c_ij_net = nn.Sequential(
-#             nn.Linear(hidden_ch, hidden_ch // 2),
-#             nn.SiLU(),
-#             nn.Linear(hidden_ch // 2, 1),
-#             nn.Sigmoid(),
-#         )
-
-#         # self.gamma_net = nn.Sequential(
-#         #     nn.Linear(2 * hidden_ch, hidden_ch),
-#         #     nn.SiLU(),
-#         #     nn.Linear(hidden_ch, hidden_ch // 2),
-#         #     nn.SiLU(),
-#         #     nn.Linear(hidden_ch // 2, 1),
-#         #     nn.Sigmoid(),
-#         # )
-
-#         self.alpha_net = nn.Sequential(
-#             nn.Linear(2 * hidden_ch, hidden_ch),
-#             nn.SiLU(),
-#             nn.Linear(hidden_ch, hidden_ch // 2),
-#             nn.SiLU(),
-#             nn.Linear(hidden_ch // 2, 1),
-#             nn.Sigmoid(),
-#         )
-
-#     def forward(self, x, c, edge_index, edge_attr=None):
-#         return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
-
-#     def message(self, x_i, x_j, c_i, c_j, edge_attr):
-#         c_ij = 0.5 * (c_i + c_j)
-#         # gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1))
-#         # gate = torch.sself.edge_attr_net(edge_attr))
-#         gate = self.edge_attr_net(edge_attr)
-#         # m = (
-#         #     gamma * self.diff_net(x_j - x_i) + (1 - gamma) * self.x_net(x_j)
-#         # ) * gate
-#         m = self.diff_net(x_j - x_i) * gate
-#         m = m * self.c_ij_net(c_ij)
-#         return m
-
-#     def update(self, aggr_out, x):
-#         alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
-#         return x + alpha * aggr_out
-
-
 class ConditionalGNOBlock(MessagePassing):
     def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
         super().__init__(aggr=aggr, node_dim=0)
-        self.edge_ch = edge_ch
-        # Network mapping edge_attr -> scalar coefficient (log-scale)
-        # If edge_ch == 0 we will use a global learnable coefficient
         self.edge_attr_net = nn.Sequential(
-            nn.Linear(edge_ch, hidden_ch),
-            nn.GELU(),
-            nn.Linear(hidden_ch, hidden_ch // 2),
+            nn.Linear(edge_ch, hidden_ch // 2),
             nn.GELU(),
             nn.Linear(hidden_ch // 2, 1),
             nn.Softplus(),
         )
-        # gating from the condition c_ij (returns a scalar in (0, 1))
+
+        self.diff_net = nn.Sequential(
+            nn.Linear(hidden_ch, hidden_ch * 2),
+            nn.GELU(),
+            nn.Linear(hidden_ch * 2, hidden_ch),
+            nn.GELU(),
+        )
+
+        self.x_net = nn.Sequential(
+            nn.Linear(hidden_ch, hidden_ch * 2),
+            nn.GELU(),
+            nn.Linear(hidden_ch * 2, hidden_ch),
+            nn.GELU(),
+        )
+
         self.c_ij_net = nn.Sequential(
-            nn.Linear(hidden_ch, hidden_ch),
+            nn.Linear(hidden_ch, hidden_ch // 2),
+            nn.GELU(),
+            nn.Linear(hidden_ch // 2, 1),
+            nn.Sigmoid(),
+        )
+
+        self.gamma_net = nn.Sequential(
+            nn.Linear(2 * hidden_ch, hidden_ch),
             nn.GELU(),
             nn.Linear(hidden_ch, hidden_ch // 2),
             nn.GELU(),
@@ -152,7 +100,6 @@ class ConditionalGNOBlock(MessagePassing):
             nn.Sigmoid(),
         )
 
-        # per-step alpha (clamped via sigmoid)
         self.alpha_net = nn.Sequential(
             nn.Linear(2 * hidden_ch, hidden_ch),
             nn.GELU(),
@@ -162,43 +109,25 @@ class ConditionalGNOBlock(MessagePassing):
             nn.Sigmoid(),
         )
 
-        self.diff_net = nn.Sequential(
-            nn.Linear(hidden_ch, hidden_ch * 2),
-            nn.GELU(),
-            nn.Linear(hidden_ch * 2, hidden_ch**2),
-            nn.GELU(),
-            nn.Linear(hidden_ch**2, hidden_ch),
-            nn.GELU(),
-        )
-
-        # self.norm = nn.LayerNorm(hidden_ch)
-
     def forward(self, x, c, edge_index, edge_attr=None):
-        # call propagate; edge_attr may be None
         return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
 
     def message(self, x_i, x_j, c_i, c_j, edge_attr):
-        """
-        Diffusive implementation:
-            m_ij = w_ij * (x_j - x_i) * c_gate_ij
-        where w_ij = softplus(edge_attr_net(edge_attr)) >= 0
-        """
-        c_ij = 0.5 * (c_i + c_j)  # [E, H]
-        c_gate = self.c_ij_net(c_ij)  # [E, 1] in (0, 1)
-        w_raw = self.edge_attr_net(edge_attr)  # [E, 1]
-        w = w_raw + 1e-8
-        diff = x_j - x_i  # [E, H]
-        m = w * self.diff_net(diff) + diff  # [E, H]
-        m = m * c_gate  # [E, H]
+        # Blend a diffusive term (x_j - x_i) with a neighbor term x_j via the
+        # gamma gate, then scale by the edge weight and the condition gate.
+        c_ij = 0.5 * (c_i + c_j)
+        gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1))
+        gate = self.edge_attr_net(edge_attr)
+        m = (
+            gamma * self.diff_net(x_j - x_i) + (1 - gamma) * self.x_net(x_j)
+        ) * gate
+        m = m * self.c_ij_net(c_ij)
        return m
 
     def update(self, aggr_out, x):
+        # Gated residual step: alpha in (0, 1) scales the aggregated message.
         alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
-        x_new = x + alpha * aggr_out
-        return x_new
+        return x + alpha * aggr_out
 
 
 class GatingGNO(nn.Module):
@@ -225,8 +154,6 @@ class GatingGNO(nn.Module):
         self,
         x,
         c,
-        boundary,
-        boundary_mask,
         edge_index,
         edge_attr=None,
         unrolling_steps=1,
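
For reviewers who want to exercise the new block locally, here is a minimal smoke test. This is a sketch, not part of the patch: it assumes torch and torch_geometric are installed, that x and the condition c both live in the hidden_ch-dimensional latent space, and that all graph sizes below are made up.

    import torch

    from ThermalSolver.model.local_gno import ConditionalGNOBlock

    N, E, H, EDGE_CH = 10, 40, 32, 3          # illustrative sizes
    x = torch.randn(N, H)                     # node features in latent space
    c = torch.randn(N, H)                     # per-node condition embedding
    edge_index = torch.randint(0, N, (2, E))  # random connectivity
    edge_attr = torch.randn(E, EDGE_CH)       # e.g. relative node positions

    block = ConditionalGNOBlock(hidden_ch=H, edge_ch=EDGE_CH, aggr="mean")
    out = block(x, c, edge_index, edge_attr=edge_attr)
    assert out.shape == (N, H)                # gated residual update keeps the shape

Note that edge_attr is still effectively required: with edge_attr=None the call to edge_attr_net raises, and the edge_ch == 0 fallback mentioned in the removed comment is not implemented.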