From dc59114f4a880bf3804e1dc25d3c0672af24499a Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Wed, 12 Nov 2025 15:20:43 +0100 Subject: [PATCH] try a new model --- ThermalSolver/graph_datamodule.py | 9 ++- ThermalSolver/graph_module.py | 55 ++++++-------- .../model/learnable_finite_difference.py | 73 +------------------ 3 files changed, 31 insertions(+), 106 deletions(-) diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py index 9920aa5..564dee6 100644 --- a/ThermalSolver/graph_datamodule.py +++ b/ThermalSolver/graph_datamodule.py @@ -122,10 +122,11 @@ class GraphDataModule(LightningDataModule): edge_index_mask = ~torch.isin(edge_index[1], boundary_idx) edge_index = edge_index[:, edge_index_mask] - edge_attr = pos[edge_index[0]] - pos[edge_index[1]] - edge_attr = torch.cat( - [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 - ) + # edge_attr = pos[edge_index[0]] - pos[edge_index[1]] + # edge_attr = torch.cat( + # [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 + # ) + edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1) x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1) if self.remove_boundary_edges: diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py index bf80043..7291977 100644 --- a/ThermalSolver/graph_module.py +++ b/ThermalSolver/graph_module.py @@ -75,9 +75,6 @@ class GraphSolver(LightningModule): def _compute_loss(self, x, y): return self.loss(x, y) - def _preprocess_batch(self, batch: Batch): - return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr - def _log_loss(self, loss, batch, stage: str): self.log( f"{stage}/loss", @@ -115,19 +112,25 @@ class GraphSolver(LightningModule): self.manual_backward(loss / max_acc_iters) return loss_.item() + def _preprocess_batch(self, batch: Batch): + x, y, c, edge_index, edge_attr = ( + batch.x, + batch.y, + batch.c, + batch.edge_index, + batch.edge_attr, + ) + edge_attr = 1 / edge_attr + c_ij = self._compute_c_ij(c, edge_index) + edge_attr = edge_attr * c_ij + return x, y, edge_index, edge_attr + def training_step(self, batch: Batch, batch_idx: int): optim = self.optimizers() optim.zero_grad() - x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) + x, y, edge_index, edge_attr = self._preprocess_batch(batch) - edge_w = 1 / edge_attr[:, -1] - c_ij = self._compute_c_ij(c, edge_index) - edge_w = edge_w * c_ij - deg = self._compute_deg(edge_index, edge_w, x.size(0)) - - edge_attr = torch.cat( - [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1 - ) + deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] acc_loss, acc_it = 0, 0 max_acc_iters = ( @@ -139,7 +142,7 @@ class GraphSolver(LightningModule): out = self._compute_model_steps( x, edge_index, - edge_attr, + edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, @@ -199,21 +202,15 @@ class GraphSolver(LightningModule): return super().on_train_epoch_end() def validation_step(self, batch: Batch, _): - x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) + x, y, edge_index, edge_attr = self._preprocess_batch(batch) - edge_w = 1 / edge_attr[:, -1] - c_ij = self._compute_c_ij(c, edge_index) - edge_w = edge_w * c_ij - deg = self._compute_deg(edge_index, edge_w, x.size(0)) + deg = self._compute_deg(edge_index, edge_attr, x.size(0)) - edge_attr = torch.cat( - [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1 - ) for i in range(self.current_iters): out = self._compute_model_steps( x, edge_index, - edge_attr, + edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, @@ -252,20 +249,15 @@ class GraphSolver(LightningModule): # loss = self._compute_loss(y_pred, y) # # _plot_mesh(batch.pos, y, y_pred, batch.batch) # self._log_loss(loss, batch, "test") - x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - edge_w = 1 / edge_attr[:, -1] - c_ij = self._compute_c_ij(c, edge_index) - edge_w = edge_w * c_ij - deg = self._compute_deg(edge_index, edge_w, x.size(0)) + x, y, edge_index, edge_attr = self._preprocess_batch(batch) + + deg = self._compute_deg(edge_index, edge_attr, x.size(0)) - edge_attr = torch.cat( - [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1 - ) for i in range(self.max_iters): out = self._compute_model_steps( x, edge_index, - edge_attr, + edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, @@ -278,7 +270,6 @@ class GraphSolver(LightningModule): loss = self.loss(out, y) self._log_loss(loss, batch, "test") - x = u self.log( "test/iterations", i + 1, diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py index 28201bb..13f3e45 100644 --- a/ThermalSolver/model/learnable_finite_difference.py +++ b/ThermalSolver/model/learnable_finite_difference.py @@ -2,40 +2,6 @@ import torch import torch.nn as nn from torch_geometric.nn import MessagePassing from torch.nn.utils import spectral_norm -from matplotlib.tri import Triangulation -from matplotlib import pyplot as plt - - -def _plot_mesh(y_pred, batch, iteration=None): - - idx = batch.batch == 0 - y = batch.y[idx].detach().cpu() - y_pred = y_pred[idx].detach().cpu() - pos = batch.pos[idx].detach().cpu() - - pos = pos.detach().cpu() - tria = Triangulation(pos[:, 0], pos[:, 1]) - plt.figure(figsize=(18, 5)) - plt.subplot(1, 3, 1) - plt.tricontourf(tria, y.squeeze().numpy(), levels=14) - plt.colorbar() - plt.title("True temperature") - plt.subplot(1, 3, 2) - plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) - plt.colorbar() - plt.title("Predicted temperature") - plt.subplot(1, 3, 3) - plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14) - plt.colorbar() - plt.title("Error") - plt.suptitle("GNO", fontsize=16) - name = ( - f"images/gno_iter_{iteration:04d}.png" - if iteration is not None - else "gno.png" - ) - plt.savefig(name, dpi=72) - plt.close() class FiniteDifferenceStep(MessagePassing): @@ -51,28 +17,12 @@ class FiniteDifferenceStep(MessagePassing): spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), ) - self.edge_embedding = nn.Sequential( - spectral_norm(nn.Linear(edge_ch, hidden_dim // 2)), - nn.GELU(), - spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), - ) - self.update_net = nn.Sequential( spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), nn.GELU(), spectral_norm(nn.Linear(hidden_dim, hidden_dim)), - nn.GELU(), - # spectral_norm(nn.Linear(hidden_dim // 2, 1)), ) - # self.message_net = nn.Sequential( - # spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), - # nn.GELU(), - # spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), - # nn.GELU(), - # spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), - # ) - self.out_net = nn.Sequential( spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), nn.GELU(), @@ -84,17 +34,14 @@ class FiniteDifferenceStep(MessagePassing): TODO: add docstring. """ x_ = self.x_embedding(x) - edge_attr_ = self.edge_embedding(edge_attr) - out = self.propagate(edge_index, x=x_, edge_attr=edge_attr_, deg=deg) + out = self.propagate(edge_index, x=x_, edge_attr=edge_attr, deg=deg) return self.out_net(x_ + out) - def message(self, x_j, edge_attr): + def message(self, x_i, x_j, edge_attr): """ TODO: add docstring. """ - # msg_input = torch.cat([x_j, edge_attr], dim=-1) - # return self.message_net(msg_input) * edge_attr[:, 3].view(-1, 1) - return x_j * edge_attr + return (x_j - x_i) * edge_attr.view(-1, 1) def update(self, aggr_out, x): """ @@ -102,10 +49,6 @@ class FiniteDifferenceStep(MessagePassing): """ update_input = torch.cat([x, aggr_out], dim=-1) return self.update_net(update_input) - # return self.update_net(aggr_out) - # return aggr_out - # h = self.update_net(aggr_out, x) - # return h def aggregate(self, inputs, index, deg): """ @@ -114,13 +57,3 @@ class FiniteDifferenceStep(MessagePassing): out = super().aggregate(inputs, index) deg = deg + 1e-7 return out / deg.view(-1, 1) - - -# # Da fare: -# # - Finire calcolo della loss su ogni step e poi media -# # - Test con vari modelli -# # - Se non dovesse funzionare, provare ad adeguare il criterio di uscita - -# # PINN batching: -# # - Provare singola condizione -# # - Ottimizzatore del secondo ordine (LBFGS)