update model, model and datamodule
This commit is contained in:
@@ -77,13 +77,6 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
self.x_net = nn.Sequential(
|
||||
nn.Linear(hidden_ch, hidden_ch * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_ch * 2, hidden_ch),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
self.c_ij_net = nn.Sequential(
|
||||
nn.Linear(hidden_ch, hidden_ch // 2),
|
||||
nn.GELU(),
|
||||
@@ -116,9 +109,6 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
c_ij = 0.5 * (c_i + c_j)
|
||||
gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1))
|
||||
gate = self.edge_attr_net(edge_attr)
|
||||
m = (
|
||||
gamma * self.diff_net(x_j - x_i) + (1 - gamma) * self.x_net(x_j)
|
||||
) * gate
|
||||
m = self.diff_net(x_j - x_i) * gate
|
||||
m = m * self.c_ij_net(c_ij)
|
||||
return m
|
||||
@@ -158,17 +148,20 @@ class GatingGNO(nn.Module):
|
||||
plot_results=False,
|
||||
batch=None,
|
||||
pos=None,
|
||||
boundary_mask=None,
|
||||
):
|
||||
x = self.encoder_x(x)
|
||||
c = self.encoder_c(c)
|
||||
if plot_results:
|
||||
x_ = self.dec(x)
|
||||
plot_results_fn(x_, pos, 0, batch=batch)
|
||||
bc = x[boundary_mask]
|
||||
for _ in range(1, unrolling_steps + 1):
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||
if plot_results:
|
||||
x_ = self.dec(x)
|
||||
assert bc == x[boundary_mask]
|
||||
plot_results_fn(x_, pos, i * _, batch=batch)
|
||||
|
||||
return self.dec(x)
|
||||
|
||||
Reference in New Issue
Block a user