diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py
index 05af1c9..fe50eb3 100644
--- a/ThermalSolver/model/finite_difference.py
+++ b/ThermalSolver/model/finite_difference.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 from torch_geometric.nn import MessagePassing
-from tqdm import tqdm
 
 
 class FiniteDifferenceStep(MessagePassing):
@@ -14,8 +13,9 @@ class FiniteDifferenceStep(MessagePassing):
         assert (
             aggr == "add"
         ), "For weighted sums, the aggregation must be 'add'."
-        self.root_weight = float(root_weight)
-        self.p = torch.nn.Parameter(torch.tensor(0.5))
+        # self.root_weight = float(root_weight)
+        self.p = torch.nn.Parameter(torch.tensor(0.8))
+        self.a = torch.as_tensor(float(root_weight))  # keep as tensor so torch.clamp works in update()
 
     def forward(self, x, edge_index, edge_attr, deg):
         """
@@ -43,7 +43,9 @@ class FiniteDifferenceStep(MessagePassing):
         """
         Convex blend of the aggregated neighbor message and the previous state.
         """
-        return self.root_weight * aggr_out + (1 - self.root_weight) * x
+        a = torch.clamp(self.a, 0.0, 1.0)
+        return a * aggr_out + (1 - a) * x
+        # return self.a * aggr_out + (1 - self.a) * x
 
 
 class GraphFiniteDifference(nn.Module):
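The clamp above makes each step a convex combination of the aggregated neighbor value and the current state, so a single update can never amplify the iterate. A minimal standalone sketch of that property on a periodic 1-D chain, in plain PyTorch (the setup is illustrative and not part of the repo):

import torch

# Any raw root weight is mapped into [0, 1] before the blend.
a = torch.clamp(torch.tensor(1.3), 0.0, 1.0)

# Node states on a periodic 1-D chain; the neighbor average plays the
# role of the degree-normalized aggregation in FiniteDifferenceStep.
x = torch.tensor([0.0, 4.0, 0.0, 0.0, 1.0])
neighbor_avg = 0.5 * (torch.roll(x, 1) + torch.roll(x, -1))

# out = a * aggr + (1 - a) * x is a convex combination, hence
# non-expansive in the max norm: the relaxation step cannot blow up.
out = a * neighbor_avg + (1 - a) * x
assert out.abs().max() <= x.abs().max()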
""" - return self.root_weight * aggr_out + (1 - self.root_weight) * x + return self.update_net(aggr_out) class GraphFiniteDifference(nn.Module): @@ -99,27 +116,14 @@ class GraphFiniteDifference(nn.Module): c_ij = self._compute_c_ij(c, edge_index) edge_attr = edge_attr * c_ij deg = self._compute_deg(edge_index, edge_attr, x.size(0)) - - # Calcola la soglia staccando x dal grafo conv_thres = self.threshold * torch.norm(x.detach()) for _i in range(self.max_iters): out = self.fd_step(x, edge_index, edge_attr, deg) out[boundary_mask] = boundary_values.unsqueeze(-1) - - # Controllo convergenza senza tracciamento gradienti with torch.no_grad(): residual_norm = torch.norm(out - x) - if residual_norm < conv_thres: break - - # --- OTTIMIZZAZIONE CHIAVE --- - # Stacca 'out' dal grafo prima della prossima iterazione - # per evitare BPTT e risparmiare memoria. x = out.detach() - - # Il 'out' finale restituito mantiene i gradienti - # dell'ULTIMA chiamata a fd_step, permettendo al modello - # di apprendere correttamente. return out, _i + 1