add warmup network

This commit is contained in:
2025-10-17 11:02:04 +02:00
parent 8f23a8af66
commit 477cc94f7f

View File

@@ -17,6 +17,24 @@ def plot_results_fn(x, pos, i, batch):
plt.close()
class WarmUpNet(nn.Module):
def __init__(self, input_dim, hidden_dim=64, n_layers=3):
super().__init__()
layers = []
layers.append(nn.Linear(input_dim, hidden_dim))
layers.append(nn.GELU())
for _ in range(n_layers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, input_dim))
self.net = nn.Sequential(*layers)
def forward(self, x, boundary_mask, boundary_values):
x = self.net(x)
x[boundary_mask] = boundary_values
return x
class EncX(nn.Module):
def __init__(self, x_ch, hidden):
super().__init__()
@@ -127,6 +145,7 @@ class GatingGNO(nn.Module):
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
):
super().__init__()
self.warmup = WarmUpNet(x_ch_node, hidden_dim=hidden, n_layers=3)
self.encoder_x = EncX(x_ch_node, hidden)
self.encoder_c = EncC(f_ch_node, hidden)
@@ -150,12 +169,12 @@ class GatingGNO(nn.Module):
pos=None,
boundary_mask=None,
):
x = self.warmup(x, boundary_mask, x[boundary_mask])
x = self.encoder_x(x)
c = self.encoder_c(c)
if plot_results:
x_ = self.dec(x)
plot_results_fn(x_, pos, 0, batch=batch)
bc = x[boundary_mask]
for _ in range(1, unrolling_steps + 1):
for i, blk in enumerate(self.blocks):
x = blk(x, c, edge_index, edge_attr=edge_attr)