add warmup network
This commit is contained in:
@@ -17,6 +17,24 @@ def plot_results_fn(x, pos, i, batch):
|
||||
plt.close()
|
||||
|
||||
|
||||
class WarmUpNet(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim=64, n_layers=3):
|
||||
super().__init__()
|
||||
layers = []
|
||||
layers.append(nn.Linear(input_dim, hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
for _ in range(n_layers - 2):
|
||||
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
layers.append(nn.Linear(hidden_dim, input_dim))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, boundary_mask, boundary_values):
|
||||
x = self.net(x)
|
||||
x[boundary_mask] = boundary_values
|
||||
return x
|
||||
|
||||
|
||||
class EncX(nn.Module):
|
||||
def __init__(self, x_ch, hidden):
|
||||
super().__init__()
|
||||
@@ -127,6 +145,7 @@ class GatingGNO(nn.Module):
|
||||
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
||||
):
|
||||
super().__init__()
|
||||
self.warmup = WarmUpNet(x_ch_node, hidden_dim=hidden, n_layers=3)
|
||||
self.encoder_x = EncX(x_ch_node, hidden)
|
||||
self.encoder_c = EncC(f_ch_node, hidden)
|
||||
|
||||
@@ -150,12 +169,12 @@ class GatingGNO(nn.Module):
|
||||
pos=None,
|
||||
boundary_mask=None,
|
||||
):
|
||||
x = self.warmup(x, boundary_mask, x[boundary_mask])
|
||||
x = self.encoder_x(x)
|
||||
c = self.encoder_c(c)
|
||||
if plot_results:
|
||||
x_ = self.dec(x)
|
||||
plot_results_fn(x_, pos, 0, batch=batch)
|
||||
bc = x[boundary_mask]
|
||||
for _ in range(1, unrolling_steps + 1):
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||
|
||||
Reference in New Issue
Block a user