diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py index 4034e36..e866fc4 100644 --- a/ThermalSolver/graph_module.py +++ b/ThermalSolver/graph_module.py @@ -4,6 +4,7 @@ from torch_geometric.data import Batch import importlib from matplotlib import pyplot as plt from matplotlib.tri import Triangulation +from .model.finite_difference import FiniteDifferenceStep def import_class(class_path: str): @@ -56,6 +57,7 @@ class GraphSolver(LightningModule): ): super().__init__() self.model = import_class(model_class_path)(**model_init_args) + self.fd_net = FiniteDifferenceStep() self.loss = loss if loss is not None else torch.nn.MSELoss() self.curriculum_learning = curriculum_learning self.start_iters = start_iters @@ -67,6 +69,8 @@ class GraphSolver(LightningModule): self.automatic_optimization = False self.threshold = 1e-5 + self.aplha = 0.1 + def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) deg = deg.scatter_add(0, edge_index[1], edge_attr) @@ -96,8 +100,15 @@ class GraphSolver(LightningModule): def _compute_model_steps( self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values ): - out = self.model(x, edge_index, edge_attr, deg) + with torch.no_grad(): + out = self.fd_net(x, edge_index, edge_attr, deg) + out[boundary_mask] = boundary_values.unsqueeze(-1) + # diff = out - x + correction = self.model(x, edge_index, edge_attr, deg) + out = out + self.aplha * correction out[boundary_mask] = boundary_values.unsqueeze(-1) + # out = self.model(x, edge_index, edge_attr, deg) + # out[boundary_mask] = boundary_values.unsqueeze(-1) return out def _check_convergence(self, out, x): @@ -132,11 +143,7 @@ class GraphSolver(LightningModule): deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] acc_loss, acc_it = 0, 0 - max_acc_iters = ( - self.current_iters // self.accumulation_iters + 1 - if self.accumulation_iters is not None - else 1 - ) + for i in range(self.current_iters): out = self._compute_model_steps( x, diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py index f183f52..18b0e85 100644 --- a/ThermalSolver/model/finite_difference.py +++ b/ThermalSolver/model/finite_difference.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from torch_geometric.nn import MessagePassing +from torch.nn.utils import spectral_norm class FiniteDifferenceStep(MessagePassing): @@ -8,14 +9,8 @@ class FiniteDifferenceStep(MessagePassing): TODO: add docstring. """ - def __init__(self, aggr: str = "add", root_weight: float = 1.0): - super().__init__(aggr=aggr) - assert ( - aggr == "add" - ), "Per somme pesate, l'aggregazione deve essere 'add'." - # self.root_weight = float(root_weight) - self.p = torch.nn.Parameter(torch.tensor(1.0)) - self.a = root_weight + def __init__(self): + super().__init__(aggr="add") def forward(self, x, edge_index, edge_attr, deg): """ @@ -28,8 +23,14 @@ class FiniteDifferenceStep(MessagePassing): """ TODO: add docstring. """ - p = torch.clamp(self.p, 0.0, 1.0) - return p * edge_attr.view(-1, 1) * x_j + # return self.message_net(x_j * edge_attr) + return x_j * edge_attr + + def update(self, aggr_out, _): + """ + TODO: add docstring. + """ + return aggr_out def aggregate(self, inputs, index, deg): """ @@ -38,82 +39,3 @@ class FiniteDifferenceStep(MessagePassing): out = super().aggregate(inputs, index) deg = deg + 1e-7 return out / deg.view(-1, 1) - - def update(self, aggr_out, x): - """ - TODO: add docstring. - """ - return aggr_out - - -class GraphFiniteDifference(nn.Module): - """ - TODO: add docstring. - """ - - def __init__(self, max_iters: int = 5000, threshold: float = 1e-4): - """ - TODO: add docstring. - """ - super().__init__() - self.max_iters = max_iters - self.threshold = threshold - self.fd_step = FiniteDifferenceStep(aggr="add", root_weight=1.0) - - @staticmethod - def _compute_deg(edge_index, edge_attr, num_nodes): - """ - TODO: add docstring. - """ - deg = torch.zeros(num_nodes, device=edge_index.device) - deg = deg.scatter_add(0, edge_index[1], edge_attr) - return deg + 1e-7 - - @staticmethod - def _compute_c_ij(c, edge_index): - """ - TODO: add docstring. - """ - return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() - - def forward( - self, - x, - edge_index, - edge_attr, - c, - boundary_mask, - boundary_values, - **kwargs, - ): - """ - TODO: add docstring. - """ - edge_attr = 1 / edge_attr[:, -1] - c_ij = self._compute_c_ij(c, edge_index) - edge_attr = edge_attr * c_ij - deg = self._compute_deg(edge_index, edge_attr, x.size(0)) - - # Calcola la soglia staccando x dal grafo - conv_thres = self.threshold * torch.norm(x.detach()) - - for _i in range(self.max_iters): - out = self.fd_step(x, edge_index, edge_attr, deg) - out[boundary_mask] = boundary_values.unsqueeze(-1) - - # Controllo convergenza senza tracciamento gradienti - with torch.no_grad(): - residual_norm = torch.norm(out - x) - - if residual_norm < conv_thres: - break - - # --- OTTIMIZZAZIONE CHIAVE --- - # Stacca 'out' dal grafo prima della prossima iterazione - # per evitare BPTT e risparmiare memoria. - x = out.detach() - - # Il 'out' finale restituito mantiene i gradienti - # dell'ULTIMA chiamata a fd_step, permettendo al modello - # di apprendere correttamente. - return out, _i + 1 diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py index 94efba1..f9ff49c 100644 --- a/ThermalSolver/model/learnable_finite_difference.py +++ b/ThermalSolver/model/learnable_finite_difference.py @@ -1,53 +1,53 @@ import torch import torch.nn as nn from torch_geometric.nn import MessagePassing -from torch.nn.utils import spectral_norm + +# from torch.nn.utils import spectral_norm -class FiniteDifferenceStep(MessagePassing): - """ - TODO: add docstring. - """ - - def __init__(self, hidden_dim=16, aggr: str = "add"): - print(aggr) - super().__init__(aggr=aggr) - self.x_embedding = nn.Sequential( - spectral_norm(nn.Linear(1, hidden_dim // 2)), - nn.GELU(), - spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), +class GCNConvLayer(MessagePassing): + def __init__(self, in_channels, out_channels): + super().__init__("add") + self.lin = nn.Sequential( + nn.Linear(in_channels, out_channels), + nn.ReLU(), + nn.Linear(out_channels, out_channels), + nn.ReLU(), ) - self.out_net = nn.Sequential( - spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), - nn.GELU(), - spectral_norm(nn.Linear(hidden_dim // 2, 1)), + def _compute_edge_weight(self, edge_index, edge_w, deg): + """ """ + return edge_w.squeeze() / ( + 1 + torch.sqrt(deg[edge_index[0]] * deg[edge_index[1]]) ) def forward(self, x, edge_index, edge_attr, deg): - """ - TODO: add docstring. - """ - x_ = self.x_embedding(x) - out = self.propagate(edge_index, x=x_, edge_attr=edge_attr, deg=deg) - return self.out_net(out) + edge_w = self._compute_edge_weight(edge_index, edge_attr, deg) + return self.propagate(edge_index, x=x, edge_weight=edge_w, deg=deg) - def message(self, x_j, edge_attr): - """ - TODO: add docstring. - """ - return x_j * edge_attr.view(-1, 1) + def message(self, x_j, edge_weight): + return edge_weight.view(-1, 1) * x_j - def update(self, aggr_out, _): - """ - TODO: add docstring. - """ - return aggr_out - def aggregate(self, inputs, index, deg): - """ - TODO: add docstring. - """ - out = super().aggregate(inputs, index) - deg = deg + 1e-7 - return out / deg.view(-1, 1) +class CorrectionNet(nn.Module): + def __init__(self, hidden_dim=8): + super().__init__() + self.enc = nn.Sequential( + nn.Linear(1, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, hidden_dim), + nn.ReLU(), + ) + self.model = GCNConvLayer(hidden_dim, hidden_dim) + self.dec = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, 1), + nn.ReLU(), + ) + + def forward(self, x, edge_index, edge_attr, deg): + h = self.enc(x) + h = self.model(h, edge_index, edge_attr, deg) + out = self.dec(h) + return out