new data format

This commit is contained in:
Filippo Olivo
2025-12-12 10:18:16 +01:00
parent 27c2aeb736
commit 732d48c360
3 changed files with 118 additions and 87 deletions

View File

@@ -33,14 +33,12 @@ class DiffusionLayer(MessagePassing):
@property
def alpha(self):
return torch.clamp(self.alpha_param, min=1e-5, max=1.0)
return torch.clamp(self.alpha_param, min=1e-7, max=1.0)
def forward(self, x, edge_index, edge_weight, conductivity):
edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
# return (1-self.alpha) * x + self.alpha * net_flux
# return net_flux + x
return x + self.alpha * net_flux
def message(self, x_i, x_j, conductance):