add model and solver that maybe works
This commit is contained in:
@@ -69,7 +69,7 @@ class GraphSolver(LightningModule):
|
||||
self.automatic_optimization = False
|
||||
self.threshold = 1e-5
|
||||
|
||||
self.aplha = 0.1
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(0.1))
|
||||
|
||||
def _compute_deg(self, edge_index, edge_attr, num_nodes):
|
||||
deg = torch.zeros(num_nodes, device=edge_index.device)
|
||||
@@ -100,15 +100,15 @@ class GraphSolver(LightningModule):
|
||||
def _compute_model_steps(
|
||||
self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values
|
||||
):
|
||||
with torch.no_grad():
|
||||
out = self.fd_net(x, edge_index, edge_attr, deg)
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# diff = out - x
|
||||
correction = self.model(x, edge_index, edge_attr, deg)
|
||||
out = out + self.aplha * correction
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# out = self.model(x, edge_index, edge_attr, deg)
|
||||
# with torch.no_grad():
|
||||
# out = self.fd_net(x, edge_index, edge_attr, deg)
|
||||
# out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# diff = out - x
|
||||
# out = self.model(out, edge_index, edge_attr, deg)
|
||||
# out = out + self.alpha * correction
|
||||
# out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
out = self.model(x, edge_index, edge_attr, deg)
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
return out
|
||||
|
||||
def _check_convergence(self, out, x):
|
||||
|
||||
Reference in New Issue
Block a user