add refined geometries in datamodule
This commit is contained in:
@@ -106,10 +106,13 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
def message(self, x_i, x_j, c_i, c_j, edge_attr):
|
||||
c_ij = 0.5 * (c_i + c_j)
|
||||
alpha = torch.sigmoid(self.balancing)
|
||||
m = alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
|
||||
gate = torch.sigmoid(self.edge_attr_net(edge_attr))
|
||||
m = (
|
||||
alpha * self.diff_net(x_j - x_i)
|
||||
+ (1 - alpha) * self.x_net(x_j) * gate
|
||||
)
|
||||
m = m * self.c_ij_net(c_ij)
|
||||
gate = self.edge_attr_net(edge_attr)
|
||||
return m * torch.sigmoid(gate)
|
||||
return m
|
||||
|
||||
def update(self, aggr_out, x):
|
||||
return x + self.alpha * self.msg_proj(aggr_out)
|
||||
|
||||
Reference in New Issue
Block a user