From 31059bf86e3969c4a500565b9db10ec2cb56c546 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 20 Nov 2025 11:38:50 +0100 Subject: [PATCH] add model and solver that maybe works --- ThermalSolver/graph_module.py | 18 +-- ThermalSolver/model/finite_difference.py | 1 - .../model/learnable_finite_difference.py | 151 +++++++++++++----- 3 files changed, 120 insertions(+), 50 deletions(-) diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py index e866fc4..88b5fcd 100644 --- a/ThermalSolver/graph_module.py +++ b/ThermalSolver/graph_module.py @@ -69,7 +69,7 @@ class GraphSolver(LightningModule): self.automatic_optimization = False self.threshold = 1e-5 - self.aplha = 0.1 + self.alpha = torch.nn.Parameter(torch.tensor(0.1)) def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) @@ -100,15 +100,15 @@ class GraphSolver(LightningModule): def _compute_model_steps( self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values ): - with torch.no_grad(): - out = self.fd_net(x, edge_index, edge_attr, deg) - out[boundary_mask] = boundary_values.unsqueeze(-1) - # diff = out - x - correction = self.model(x, edge_index, edge_attr, deg) - out = out + self.aplha * correction - out[boundary_mask] = boundary_values.unsqueeze(-1) - # out = self.model(x, edge_index, edge_attr, deg) + # with torch.no_grad(): + # out = self.fd_net(x, edge_index, edge_attr, deg) # out[boundary_mask] = boundary_values.unsqueeze(-1) + # diff = out - x + # out = self.model(out, edge_index, edge_attr, deg) + # out = out + self.alpha * correction + # out[boundary_mask] = boundary_values.unsqueeze(-1) + out = self.model(x, edge_index, edge_attr, deg) + out[boundary_mask] = boundary_values.unsqueeze(-1) return out def _check_convergence(self, out, x): diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py index 18b0e85..89da373 100644 --- a/ThermalSolver/model/finite_difference.py +++ b/ThermalSolver/model/finite_difference.py @@ -23,7 +23,6 @@ class FiniteDifferenceStep(MessagePassing): """ TODO: add docstring. """ - # return self.message_net(x_j * edge_attr) return x_j * edge_attr def update(self, aggr_out, _): diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py index f9ff49c..f78fff0 100644 --- a/ThermalSolver/model/learnable_finite_difference.py +++ b/ThermalSolver/model/learnable_finite_difference.py @@ -1,53 +1,124 @@ +# import torch +# import torch.nn as nn +# from torch_geometric.nn import MessagePassing +# from torch.nn.utils import spectral_norm + +# class GCNConvLayer(MessagePassing): +# def __init__(self, in_channels, out_channels): +# super().__init__(aggr="add") +# self.lin_l = spectral_norm(nn.Linear(in_channels, out_channels, bias=False)) +# self.lin_r = spectral_norm(nn.Linear(in_channels, out_channels, bias=False)) + +# def forward(self, x, edge_index, edge_attr, deg): +# out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg) +# out = self.lin_l(out) +# return out + +# def message(self, x_j, edge_attr): +# return x_j * edge_attr + +# def aggregate(self, inputs, index, deg): +# """ +# TODO: add docstring. +# """ +# out = super().aggregate(inputs, index) +# deg = deg + 1e-7 +# return out / deg.view(-1, 1) + + +# class CorrectionNet(nn.Module): +# def __init__(self, hidden_dim=8, n_layers=1): +# super().__init__() +# # self.enc = GCNConvLayer(1, hidden_dim) +# self.enc = nn.Sequential( +# spectral_norm(nn.Linear(1, hidden_dim//2)), +# nn.GELU(), +# spectral_norm(nn.Linear(hidden_dim//2, hidden_dim)), +# ) +# self.layers = torch.nn.ModuleList([GCNConvLayer(hidden_dim, hidden_dim) for _ in range(n_layers)]) +# self.relu = nn.GELU() + +# self.dec = nn.Sequential( +# spectral_norm(nn.Linear(hidden_dim, hidden_dim//2)), +# nn.GELU(), +# spectral_norm(nn.Linear(hidden_dim//2, 1)), +# ) + +# def forward(self, x, edge_index, edge_attr, deg,): +# # h = self.enc(x, edge_index, edge_attr, deg) +# # h = self.relu(self.enc(x)) +# h = self.enc(x) +# for layer in self.layers: +# h = layer(h, edge_index, edge_attr, deg) +# # h = self.norm(h) +# h = self.relu(h) +# # out = self.dec(h, edge_index, edge_attr, deg) +# out = self.dec(h) +# return out + + import torch import torch.nn as nn from torch_geometric.nn import MessagePassing - -# from torch.nn.utils import spectral_norm +from torch.nn.utils import spectral_norm -class GCNConvLayer(MessagePassing): - def __init__(self, in_channels, out_channels): - super().__init__("add") - self.lin = nn.Sequential( - nn.Linear(in_channels, out_channels), - nn.ReLU(), - nn.Linear(out_channels, out_channels), - nn.ReLU(), +class CorrectionNet(MessagePassing): + """ + TODO: add docstring. + """ + + def __init__(self, hidden_dim=16): + super().__init__(aggr="add") + self.in_net = nn.Sequential( + spectral_norm(nn.Linear(1, hidden_dim // 2)), + nn.GELU(), + spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), ) - def _compute_edge_weight(self, edge_index, edge_w, deg): - """ """ - return edge_w.squeeze() / ( - 1 + torch.sqrt(deg[edge_index[0]] * deg[edge_index[1]]) + self.out_net = nn.Sequential( + spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), + nn.GELU(), + spectral_norm(nn.Linear(hidden_dim // 2, 1)), ) + self.lin_msg = spectral_norm( + nn.Linear(hidden_dim, hidden_dim, bias=False) + ) + self.lin_update = spectral_norm( + nn.Linear(hidden_dim, hidden_dim, bias=False) + ) + self.alpha = nn.Parameter(torch.tensor(0.0)) + self.beta = nn.Parameter(torch.tensor(0.0)) + def forward(self, x, edge_index, edge_attr, deg): - edge_w = self._compute_edge_weight(edge_index, edge_attr, deg) - return self.propagate(edge_index, x=x, edge_weight=edge_w, deg=deg) + """ + TODO: add docstring. + """ + x = self.in_net(x) + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg) + return self.out_net(out) - def message(self, x_j, edge_weight): - return edge_weight.view(-1, 1) * x_j + def message(self, x_j, edge_attr): + """ + TODO: add docstring. + """ + alpha = torch.sigmoid(self.alpha) + msg = x_j * edge_attr + msg = (1 - alpha) * msg + alpha * self.lin_msg(msg) + return msg + def update(self, aggr_out, x): + """ + TODO: add docstring. + """ + beta = torch.sigmoid(self.beta) + return aggr_out * (1 - beta) + self.lin_msg(x) * beta -class CorrectionNet(nn.Module): - def __init__(self, hidden_dim=8): - super().__init__() - self.enc = nn.Sequential( - nn.Linear(1, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, hidden_dim), - nn.ReLU(), - ) - self.model = GCNConvLayer(hidden_dim, hidden_dim) - self.dec = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, 1), - nn.ReLU(), - ) - - def forward(self, x, edge_index, edge_attr, deg): - h = self.enc(x) - h = self.model(h, edge_index, edge_attr, deg) - out = self.dec(h) - return out + def aggregate(self, inputs, index, deg): + """ + TODO: add docstring. + """ + out = super().aggregate(inputs, index) + deg = deg + 1e-7 + return out / deg.view(-1, 1)