add model and solver that maybe works

This commit is contained in:
Filippo Olivo
2025-11-20 11:38:50 +01:00
parent d865556c9f
commit 31059bf86e
3 changed files with 120 additions and 50 deletions

View File

@@ -69,7 +69,7 @@ class GraphSolver(LightningModule):
self.automatic_optimization = False self.automatic_optimization = False
self.threshold = 1e-5 self.threshold = 1e-5
self.aplha = 0.1 self.alpha = torch.nn.Parameter(torch.tensor(0.1))
def _compute_deg(self, edge_index, edge_attr, num_nodes): def _compute_deg(self, edge_index, edge_attr, num_nodes):
deg = torch.zeros(num_nodes, device=edge_index.device) deg = torch.zeros(num_nodes, device=edge_index.device)
@@ -100,15 +100,15 @@ class GraphSolver(LightningModule):
def _compute_model_steps( def _compute_model_steps(
self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values
): ):
with torch.no_grad(): # with torch.no_grad():
out = self.fd_net(x, edge_index, edge_attr, deg) # out = self.fd_net(x, edge_index, edge_attr, deg)
out[boundary_mask] = boundary_values.unsqueeze(-1)
# diff = out - x
correction = self.model(x, edge_index, edge_attr, deg)
out = out + self.aplha * correction
out[boundary_mask] = boundary_values.unsqueeze(-1)
# out = self.model(x, edge_index, edge_attr, deg)
# out[boundary_mask] = boundary_values.unsqueeze(-1) # out[boundary_mask] = boundary_values.unsqueeze(-1)
# diff = out - x
# out = self.model(out, edge_index, edge_attr, deg)
# out = out + self.alpha * correction
# out[boundary_mask] = boundary_values.unsqueeze(-1)
out = self.model(x, edge_index, edge_attr, deg)
out[boundary_mask] = boundary_values.unsqueeze(-1)
return out return out
def _check_convergence(self, out, x): def _check_convergence(self, out, x):

View File

@@ -23,7 +23,6 @@ class FiniteDifferenceStep(MessagePassing):
""" """
TODO: add docstring. TODO: add docstring.
""" """
# return self.message_net(x_j * edge_attr)
return x_j * edge_attr return x_j * edge_attr
def update(self, aggr_out, _): def update(self, aggr_out, _):

View File

@@ -1,53 +1,124 @@
# import torch
# import torch.nn as nn
# from torch_geometric.nn import MessagePassing
# from torch.nn.utils import spectral_norm
# class GCNConvLayer(MessagePassing):
# def __init__(self, in_channels, out_channels):
# super().__init__(aggr="add")
# self.lin_l = spectral_norm(nn.Linear(in_channels, out_channels, bias=False))
# self.lin_r = spectral_norm(nn.Linear(in_channels, out_channels, bias=False))
# def forward(self, x, edge_index, edge_attr, deg):
# out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
# out = self.lin_l(out)
# return out
# def message(self, x_j, edge_attr):
# return x_j * edge_attr
# def aggregate(self, inputs, index, deg):
# """
# TODO: add docstring.
# """
# out = super().aggregate(inputs, index)
# deg = deg + 1e-7
# return out / deg.view(-1, 1)
# class CorrectionNet(nn.Module):
# def __init__(self, hidden_dim=8, n_layers=1):
# super().__init__()
# # self.enc = GCNConvLayer(1, hidden_dim)
# self.enc = nn.Sequential(
# spectral_norm(nn.Linear(1, hidden_dim//2)),
# nn.GELU(),
# spectral_norm(nn.Linear(hidden_dim//2, hidden_dim)),
# )
# self.layers = torch.nn.ModuleList([GCNConvLayer(hidden_dim, hidden_dim) for _ in range(n_layers)])
# self.relu = nn.GELU()
# self.dec = nn.Sequential(
# spectral_norm(nn.Linear(hidden_dim, hidden_dim//2)),
# nn.GELU(),
# spectral_norm(nn.Linear(hidden_dim//2, 1)),
# )
# def forward(self, x, edge_index, edge_attr, deg,):
# # h = self.enc(x, edge_index, edge_attr, deg)
# # h = self.relu(self.enc(x))
# h = self.enc(x)
# for layer in self.layers:
# h = layer(h, edge_index, edge_attr, deg)
# # h = self.norm(h)
# h = self.relu(h)
# # out = self.dec(h, edge_index, edge_attr, deg)
# out = self.dec(h)
# return out
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch_geometric.nn import MessagePassing from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm
# from torch.nn.utils import spectral_norm
class GCNConvLayer(MessagePassing): class CorrectionNet(MessagePassing):
def __init__(self, in_channels, out_channels): """
super().__init__("add") TODO: add docstring.
self.lin = nn.Sequential( """
nn.Linear(in_channels, out_channels),
nn.ReLU(), def __init__(self, hidden_dim=16):
nn.Linear(out_channels, out_channels), super().__init__(aggr="add")
nn.ReLU(), self.in_net = nn.Sequential(
spectral_norm(nn.Linear(1, hidden_dim // 2)),
nn.GELU(),
spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
) )
def _compute_edge_weight(self, edge_index, edge_w, deg): self.out_net = nn.Sequential(
""" """ spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
return edge_w.squeeze() / ( nn.GELU(),
1 + torch.sqrt(deg[edge_index[0]] * deg[edge_index[1]]) spectral_norm(nn.Linear(hidden_dim // 2, 1)),
) )
self.lin_msg = spectral_norm(
nn.Linear(hidden_dim, hidden_dim, bias=False)
)
self.lin_update = spectral_norm(
nn.Linear(hidden_dim, hidden_dim, bias=False)
)
self.alpha = nn.Parameter(torch.tensor(0.0))
self.beta = nn.Parameter(torch.tensor(0.0))
def forward(self, x, edge_index, edge_attr, deg): def forward(self, x, edge_index, edge_attr, deg):
edge_w = self._compute_edge_weight(edge_index, edge_attr, deg) """
return self.propagate(edge_index, x=x, edge_weight=edge_w, deg=deg) TODO: add docstring.
"""
x = self.in_net(x)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
return self.out_net(out)
def message(self, x_j, edge_weight): def message(self, x_j, edge_attr):
return edge_weight.view(-1, 1) * x_j """
TODO: add docstring.
"""
alpha = torch.sigmoid(self.alpha)
msg = x_j * edge_attr
msg = (1 - alpha) * msg + alpha * self.lin_msg(msg)
return msg
def update(self, aggr_out, x):
"""
TODO: add docstring.
"""
beta = torch.sigmoid(self.beta)
return aggr_out * (1 - beta) + self.lin_msg(x) * beta
class CorrectionNet(nn.Module): def aggregate(self, inputs, index, deg):
def __init__(self, hidden_dim=8): """
super().__init__() TODO: add docstring.
self.enc = nn.Sequential( """
nn.Linear(1, hidden_dim // 2), out = super().aggregate(inputs, index)
nn.ReLU(), deg = deg + 1e-7
nn.Linear(hidden_dim // 2, hidden_dim), return out / deg.view(-1, 1)
nn.ReLU(),
)
self.model = GCNConvLayer(hidden_dim, hidden_dim)
self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
nn.ReLU(),
)
def forward(self, x, edge_index, edge_attr, deg):
h = self.enc(x)
h = self.model(h, edge_index, edge_attr, deg)
out = self.dec(h)
return out