add model and solver that maybe works

This commit is contained in:
Filippo Olivo
2025-11-20 11:38:50 +01:00
parent d865556c9f
commit 31059bf86e
3 changed files with 120 additions and 50 deletions

View File

@@ -69,7 +69,7 @@ class GraphSolver(LightningModule):
self.automatic_optimization = False
self.threshold = 1e-5
self.aplha = 0.1
self.alpha = torch.nn.Parameter(torch.tensor(0.1))
def _compute_deg(self, edge_index, edge_attr, num_nodes):
deg = torch.zeros(num_nodes, device=edge_index.device)
@@ -100,15 +100,15 @@ class GraphSolver(LightningModule):
def _compute_model_steps(
self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values
):
with torch.no_grad():
out = self.fd_net(x, edge_index, edge_attr, deg)
out[boundary_mask] = boundary_values.unsqueeze(-1)
# diff = out - x
correction = self.model(x, edge_index, edge_attr, deg)
out = out + self.aplha * correction
out[boundary_mask] = boundary_values.unsqueeze(-1)
# out = self.model(x, edge_index, edge_attr, deg)
# with torch.no_grad():
# out = self.fd_net(x, edge_index, edge_attr, deg)
# out[boundary_mask] = boundary_values.unsqueeze(-1)
# diff = out - x
# out = self.model(out, edge_index, edge_attr, deg)
# out = out + self.alpha * correction
# out[boundary_mask] = boundary_values.unsqueeze(-1)
out = self.model(x, edge_index, edge_attr, deg)
out[boundary_mask] = boundary_values.unsqueeze(-1)
return out
def _check_convergence(self, out, x):