implement new model
This commit is contained in:
@@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing
|
||||
from matplotlib.tri import Triangulation
|
||||
|
||||
|
||||
def _import_boundary_conditions(x, boundary, boundary_mask):
|
||||
x[boundary_mask] = boundary
|
||||
|
||||
|
||||
def plot_results_fn(x, pos, i, batch):
|
||||
x = x[batch == 0]
|
||||
pos = pos[batch == 0]
|
||||
@@ -63,7 +59,6 @@ class DecX(nn.Module):
|
||||
class ConditionalGNOBlock(MessagePassing):
|
||||
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
|
||||
super().__init__(aggr=aggr, node_dim=0)
|
||||
# self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
|
||||
|
||||
self.edge_attr_net = nn.Sequential(
|
||||
nn.Linear(edge_ch, hidden_ch // 2),
|
||||
@@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
)
|
||||
|
||||
self.msg_proj = nn.Sequential(
|
||||
nn.Linear(hidden_ch, hidden_ch),
|
||||
nn.Linear(hidden_ch, hidden_ch, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch, hidden_ch),
|
||||
nn.Linear(hidden_ch, hidden_ch, bias=False),
|
||||
)
|
||||
|
||||
self.diff_net = nn.Sequential(
|
||||
@@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
)
|
||||
|
||||
self.balancing = nn.Parameter(torch.tensor(0.0))
|
||||
self.alpha = nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
self.alpha_net = nn.Sequential(
|
||||
nn.Linear(2 * hidden_ch, hidden_ch),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch, hidden_ch // 2),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch // 2, 1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x, c, edge_index, edge_attr=None):
|
||||
return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
|
||||
@@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
alpha = torch.sigmoid(self.balancing)
|
||||
gate = torch.sigmoid(self.edge_attr_net(edge_attr))
|
||||
m = (
|
||||
alpha * self.diff_net(x_j - x_i)
|
||||
+ (1 - alpha) * self.x_net(x_j) * gate
|
||||
)
|
||||
alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
|
||||
) * gate
|
||||
m = m * self.c_ij_net(c_ij)
|
||||
return m
|
||||
|
||||
def update(self, aggr_out, x):
|
||||
return x + self.alpha * self.msg_proj(aggr_out)
|
||||
alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
|
||||
return x + alpha * self.msg_proj(aggr_out)
|
||||
|
||||
|
||||
class GatingGNO(nn.Module):
|
||||
@@ -153,14 +156,10 @@ class GatingGNO(nn.Module):
|
||||
):
|
||||
x = self.encoder_x(x)
|
||||
c = self.encoder_c(c)
|
||||
boundary = self.encoder_x(boundary)
|
||||
if plot_results:
|
||||
_import_boundary_conditions(x, boundary, boundary_mask)
|
||||
x_ = self.dec(x)
|
||||
plot_results_fn(x_, pos, 0, batch=batch)
|
||||
|
||||
for _ in range(1, unrolling_steps + 1):
|
||||
_import_boundary_conditions(x, boundary, boundary_mask)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||
if plot_results:
|
||||
|
||||
Reference in New Issue
Block a user