add warmup network
This commit is contained in:
@@ -17,6 +17,24 @@ def plot_results_fn(x, pos, i, batch):
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
class WarmUpNet(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim=64, n_layers=3):
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
layers.append(nn.Linear(input_dim, hidden_dim))
|
||||||
|
layers.append(nn.GELU())
|
||||||
|
for _ in range(n_layers - 2):
|
||||||
|
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
||||||
|
layers.append(nn.GELU())
|
||||||
|
layers.append(nn.Linear(hidden_dim, input_dim))
|
||||||
|
self.net = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x, boundary_mask, boundary_values):
|
||||||
|
x = self.net(x)
|
||||||
|
x[boundary_mask] = boundary_values
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class EncX(nn.Module):
|
class EncX(nn.Module):
|
||||||
def __init__(self, x_ch, hidden):
|
def __init__(self, x_ch, hidden):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -127,6 +145,7 @@ class GatingGNO(nn.Module):
|
|||||||
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.warmup = WarmUpNet(x_ch_node, hidden_dim=hidden, n_layers=3)
|
||||||
self.encoder_x = EncX(x_ch_node, hidden)
|
self.encoder_x = EncX(x_ch_node, hidden)
|
||||||
self.encoder_c = EncC(f_ch_node, hidden)
|
self.encoder_c = EncC(f_ch_node, hidden)
|
||||||
|
|
||||||
@@ -150,12 +169,12 @@ class GatingGNO(nn.Module):
|
|||||||
pos=None,
|
pos=None,
|
||||||
boundary_mask=None,
|
boundary_mask=None,
|
||||||
):
|
):
|
||||||
|
x = self.warmup(x, boundary_mask, x[boundary_mask])
|
||||||
x = self.encoder_x(x)
|
x = self.encoder_x(x)
|
||||||
c = self.encoder_c(c)
|
c = self.encoder_c(c)
|
||||||
if plot_results:
|
if plot_results:
|
||||||
x_ = self.dec(x)
|
x_ = self.dec(x)
|
||||||
plot_results_fn(x_, pos, 0, batch=batch)
|
plot_results_fn(x_, pos, 0, batch=batch)
|
||||||
bc = x[boundary_mask]
|
|
||||||
for _ in range(1, unrolling_steps + 1):
|
for _ in range(1, unrolling_steps + 1):
|
||||||
for i, blk in enumerate(self.blocks):
|
for i, blk in enumerate(self.blocks):
|
||||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||||
|
|||||||
Reference in New Issue
Block a user