From 35770ace67a1ebd6278abe7e48a833207a583259 Mon Sep 17 00:00:00 2001
From: Filippo Olivo
Date: Mon, 27 Oct 2025 10:54:35 +0100
Subject: [PATCH] add learnable finite difference and fix __init__.py

---
 ThermalSolver/model/__init__.py               |  12 +-
 ThermalSolver/model/finite_difference.py      |  41 ++++--
 .../model/learnable_finite_difference.py      | 125 ++++++++++++++++++
 3 files changed, 162 insertions(+), 16 deletions(-)
 create mode 100644 ThermalSolver/model/learnable_finite_difference.py

diff --git a/ThermalSolver/model/__init__.py b/ThermalSolver/model/__init__.py
index 2538982..37adaca 100644
--- a/ThermalSolver/model/__init__.py
+++ b/ThermalSolver/model/__init__.py
@@ -1,5 +1,13 @@
-__all__ = ["GraphFiniteDifference", "GatingGNO"]
+__all__ = [
+    "GraphFiniteDifference",
+    "GatingGNO",
+    "LearnableGraphFiniteDifference",
+    "PointNet",
+]
 
-from .learnable_finite_difference import GraphFiniteDifference
+from .learnable_finite_difference import (
+    GraphFiniteDifference as LearnableGraphFiniteDifference,
+)
+from .finite_difference import GraphFiniteDifference as GraphFiniteDifference
 from .local_gno import GatingGNO
 from .point_net import PointNet
diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py
index 814eb84..05af1c9 100644
--- a/ThermalSolver/model/finite_difference.py
+++ b/ThermalSolver/model/finite_difference.py
@@ -15,21 +15,21 @@ class FiniteDifferenceStep(MessagePassing):
             aggr == "add"
         ), "For weighted sums, the aggregation must be 'add'."
         self.root_weight = float(root_weight)
+        self.p = torch.nn.Parameter(torch.tensor(0.5))
 
-    def forward(self, x, edge_index, edge_attr, deg, weight=1.0):
+    def forward(self, x, edge_index, edge_attr, deg):
         """
         TODO: add docstring.
         """
-        out = self.propagate(
-            edge_index, x=x, edge_attr=edge_attr, deg=deg, weight=weight
-        )
+        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
         return out
 
     def message(self, x_j, edge_attr):
         """
         TODO: add docstring.
         """
-        return edge_attr.view(-1, 1) * x_j
+        p = torch.clamp(self.p, 0.0, 1.0)
+        return p * edge_attr.view(-1, 1) * x_j
 
     def aggregate(self, inputs, index, deg):
         """
@@ -39,12 +39,11 @@ class FiniteDifferenceStep(MessagePassing):
         deg = deg + 1e-7
         return out / deg.view(-1, 1)
 
-    def update(self, aggr_out, x, weight):
+    def update(self, aggr_out, x):
         """
         TODO: add docstring.
         """
-        print(weight)
-        return weight * aggr_out + (1 - weight) * x
+        return self.root_weight * aggr_out + (1 - self.root_weight) * x
 
 
 class GraphFiniteDifference(nn.Module):
@@ -94,13 +93,27 @@ class GraphFiniteDifference(nn.Module):
         c_ij = self._compute_c_ij(c, edge_index)
         edge_attr = edge_attr * c_ij
         deg = self._compute_deg(edge_index, edge_attr, x.size(0))
-        conv_thres = self.threshold * torch.norm(x)
-        weight = 1.0
+
+        # Compute the convergence threshold with x detached from the graph
+        conv_thres = self.threshold * torch.norm(x.detach())
+
         for _i in range(self.max_iters):
-            out = self.fd_step(x, edge_index, edge_attr, deg, weight=weight)
-            weight = weight * 0.9999
+            out = self.fd_step(x, edge_index, edge_attr, deg)
             out[boundary_mask] = boundary_values.unsqueeze(-1)
-            if torch.norm(out - x) < conv_thres:
+
+            # Convergence check without gradient tracking
+            with torch.no_grad():
+                residual_norm = torch.norm(out - x)
+
+            if residual_norm < conv_thres:
                 break
-            x = out
+
+            # --- KEY OPTIMIZATION ---
+            # Detach 'out' from the graph before the next iteration
+            # to avoid BPTT and save memory.
+            x = out.detach()
+
+        # The final 'out' that is returned keeps the gradients
+        # of the LAST call to fd_step, so the model
+        # can still learn correctly.
         return out, _i + 1
diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py
new file mode 100644
index 0000000..e4f44f1
--- /dev/null
+++ b/ThermalSolver/model/learnable_finite_difference.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+from torch_geometric.nn import MessagePassing
+from tqdm import tqdm
+
+
+class FiniteDifferenceStep(MessagePassing):
+    """
+    TODO: add docstring.
+    """
+
+    def __init__(self, aggr: str = "add", root_weight: float = 1.0):
+        super().__init__(aggr=aggr)
+        assert (
+            aggr == "add"
+        ), "For weighted sums, the aggregation must be 'add'."
+        self.root_weight = float(root_weight)
+        self.correction_net = nn.Sequential(
+            nn.Linear(2, 16),
+            nn.GELU(),
+            nn.Linear(16, 1),
+            nn.Softplus(),
+        )
+
+    def forward(self, x, edge_index, edge_attr, deg):
+        """
+        TODO: add docstring.
+        """
+        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
+        return out
+
+    def message(self, x_j, edge_attr):
+        """
+        TODO: add docstring.
+        """
+        x_in = torch.cat([x_j, edge_attr.view(-1, 1)], dim=-1)
+        correction = self.correction_net(x_in)
+        return edge_attr.view(-1, 1) * x_j + correction
+
+    def aggregate(self, inputs, index, deg):
+        """
+        TODO: add docstring.
+        """
+        out = super().aggregate(inputs, index)
+        deg = deg + 1e-7
+        return out / deg.view(-1, 1)
+
+    def update(self, aggr_out, x):
+        """
+        TODO: add docstring.
+        """
+        return self.root_weight * aggr_out + (1 - self.root_weight) * x
+
+
+class GraphFiniteDifference(nn.Module):
+    """
+    TODO: add docstring.
+    """
+
+    def __init__(self, max_iters: int = 5000, threshold: float = 1e-4):
+        """
+        TODO: add docstring.
+        """
+        super().__init__()
+        self.max_iters = max_iters
+        self.threshold = threshold
+        self.fd_step = FiniteDifferenceStep(aggr="add", root_weight=1.0)
+
+    @staticmethod
+    def _compute_deg(edge_index, edge_attr, num_nodes):
+        """
+        TODO: add docstring.
+        """
+        deg = torch.zeros(num_nodes, device=edge_index.device)
+        deg = deg.scatter_add(0, edge_index[1], edge_attr)
+        return deg + 1e-7
+
+    @staticmethod
+    def _compute_c_ij(c, edge_index):
+        """
+        TODO: add docstring.
+        """
+        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()
+
+    def forward(
+        self,
+        x,
+        edge_index,
+        edge_attr,
+        c,
+        boundary_mask,
+        boundary_values,
+        **kwargs,
+    ):
+        """
+        TODO: add docstring.
+        """
+        edge_attr = 1 / edge_attr[:, -1]
+        c_ij = self._compute_c_ij(c, edge_index)
+        edge_attr = edge_attr * c_ij
+        deg = self._compute_deg(edge_index, edge_attr, x.size(0))
+
+        # Compute the convergence threshold with x detached from the graph
+        conv_thres = self.threshold * torch.norm(x.detach())
+
+        for _i in range(self.max_iters):
+            out = self.fd_step(x, edge_index, edge_attr, deg)
+            out[boundary_mask] = boundary_values.unsqueeze(-1)
+
+            # Convergence check without gradient tracking
+            with torch.no_grad():
+                residual_norm = torch.norm(out - x)
+
+            if residual_norm < conv_thres:
+                break
+
+            # --- KEY OPTIMIZATION ---
+            # Detach 'out' from the graph before the next iteration
+            # to avoid BPTT and save memory.
+            x = out.detach()
+
+        # The final 'out' that is returned keeps the gradients
+        # of the LAST call to fd_step, so the model
+        # can still learn correctly.
+        return out, _i + 1
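Below, for reference, is a minimal usage sketch of the learnable solver as exported by the updated __init__.py. The toy graph, tensor shapes, and values are illustrative assumptions inferred from the forward() signature (node features x of shape [N, 1], edge_attr whose last column is treated as an edge length, per-node coefficient c, and boundary_mask / boundary_values for the fixed nodes); they are not taken from the repository.

# Hypothetical usage sketch; shapes and values are assumptions, not repository data.
import torch

from ThermalSolver.model import LearnableGraphFiniteDifference

# Toy chain graph with 4 nodes and bidirectional edges.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_edges = edge_index.size(1)

x = torch.rand(4, 1)                       # initial guess for the field, [N, 1]
edge_attr = torch.ones(num_edges, 1)       # last column is used as edge length
c = torch.ones(4, 1)                       # per-node coefficient (e.g. conductivity)
boundary_mask = torch.tensor([True, False, False, True])
boundary_values = torch.tensor([1.0, 0.0])  # values imposed at the masked nodes

model = LearnableGraphFiniteDifference(max_iters=1000, threshold=1e-4)
out, n_iters = model(x, edge_index, edge_attr, c, boundary_mask, boundary_values)
print(out.shape, n_iters)  # torch.Size([4, 1]), iterations actually run

# Gradients flow only through the last fd_step call (see the comments in forward),
# so a loss on 'out' still trains correction_net without BPTT over all iterations.
loss = out.pow(2).mean()
loss.backward()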