new data format
This commit is contained in:
@@ -33,14 +33,12 @@ class DiffusionLayer(MessagePassing):
|
||||
|
||||
@property
|
||||
def alpha(self):
|
||||
return torch.clamp(self.alpha_param, min=1e-5, max=1.0)
|
||||
return torch.clamp(self.alpha_param, min=1e-7, max=1.0)
|
||||
|
||||
def forward(self, x, edge_index, edge_weight, conductivity):
|
||||
edge_weight = edge_weight.unsqueeze(-1)
|
||||
conductance = self.phys_encoder(edge_weight)
|
||||
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
|
||||
# return (1-self.alpha) * x + self.alpha * net_flux
|
||||
# return net_flux + x
|
||||
return x + self.alpha * net_flux
|
||||
|
||||
def message(self, x_i, x_j, conductance):
|
||||
|
||||
Reference in New Issue
Block a user