From 477cc94f7fa83d1a301b9c9536e70847e37144a7 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 17 Oct 2025 11:02:04 +0200 Subject: [PATCH] add warmup network --- ThermalSolver/model/local_gno.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index 7e2e022..1c93401 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -17,6 +17,24 @@ def plot_results_fn(x, pos, i, batch): plt.close() +class WarmUpNet(nn.Module): + def __init__(self, input_dim, hidden_dim=64, n_layers=3): + super().__init__() + layers = [] + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.GELU()) + for _ in range(n_layers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, input_dim)) + self.net = nn.Sequential(*layers) + + def forward(self, x, boundary_mask, boundary_values): + x = self.net(x) + x[boundary_mask] = boundary_values + return x + + class EncX(nn.Module): def __init__(self, x_ch, hidden): super().__init__() @@ -127,6 +145,7 @@ class GatingGNO(nn.Module): self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1 ): super().__init__() + self.warmup = WarmUpNet(x_ch_node, hidden_dim=hidden, n_layers=3) self.encoder_x = EncX(x_ch_node, hidden) self.encoder_c = EncC(f_ch_node, hidden) @@ -150,12 +169,12 @@ class GatingGNO(nn.Module): pos=None, boundary_mask=None, ): + x = self.warmup(x, boundary_mask, x[boundary_mask]) x = self.encoder_x(x) c = self.encoder_c(c) if plot_results: x_ = self.dec(x) plot_results_fn(x_, pos, 0, batch=batch) - bc = x[boundary_mask] for _ in range(1, unrolling_steps + 1): for i, blk in enumerate(self.blocks): x = blk(x, c, edge_index, edge_attr=edge_attr)