From ea9cf7c57c4a92641059707ba6ad8ab0dd61c309 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 13 Nov 2025 16:18:54 +0100 Subject: [PATCH] add final loss and change model --- ThermalSolver/graph_module.py | 42 ++++++++----------- .../model/learnable_finite_difference.py | 15 +++---- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py index 7291977..35aec81 100644 --- a/ThermalSolver/graph_module.py +++ b/ThermalSolver/graph_module.py @@ -65,7 +65,7 @@ class GraphSolver(LightningModule): self.current_iters = start_iters self.accumulation_iters = accumulation_iters self.automatic_optimization = False - self.threshold = 1e-2 + self.threshold = 1e-5 def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) @@ -102,14 +102,14 @@ class GraphSolver(LightningModule): def _check_convergence(self, out, x): residual_norm = torch.norm(out - x) - if residual_norm < self.threshold: + if residual_norm < self.threshold * torch.norm(x): return True return False def accumulate_gradients(self, losses, max_acc_iters): loss_ = torch.stack(losses, dim=0).mean() - loss = loss_ / self.accumulation_iters - self.manual_backward(loss / max_acc_iters) + loss = 0.5 * loss_ / self.accumulation_iters + self.manual_backward(loss / max_acc_iters, retain_graph=True) return loss_.item() def _preprocess_batch(self, batch: Batch): @@ -177,7 +177,9 @@ class GraphSolver(LightningModule): acc_loss = acc_loss + loss x = out - + if i % self.accumulation_iters != 0: + loss = self.loss(out, y) + loss.backward() optim.step() optim.zero_grad() @@ -190,6 +192,15 @@ class GraphSolver(LightningModule): prog_bar=True, batch_size=int(batch.num_graphs), ) + if hasattr(self.model, "p"): + self.log( + "train/p", + self.model.p, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=int(batch.num_graphs), + ) def on_train_epoch_end(self): if self.curriculum_learning: @@ -220,7 +231,7 @@ class GraphSolver(LightningModule): if converged: break x = out - loss = self.loss(out, y) + loss = 0.5 * self.loss(out, y) self._log_loss(loss, batch, "val") self.log( "val/iterations", @@ -232,23 +243,6 @@ class GraphSolver(LightningModule): ) def test_step(self, batch: Batch, _): - # x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - # y_pred, _ = self.model( - # x, - # edge_index, - # edge_attr, - # c, - # batch.boundary_mask, - # batch.boundary_values, - # y=None, - # loss_fn=None, - # max_iters=1000, - # plot_results=True, - # batch=batch, - # ) - # loss = self._compute_loss(y_pred, y) - # # _plot_mesh(batch.pos, y, y_pred, batch.batch) - # self._log_loss(loss, batch, "test") x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) @@ -263,7 +257,7 @@ class GraphSolver(LightningModule): batch.boundary_values, ) converged = self._check_convergence(out, x) - _plot_mesh(batch.pos, y, out, batch.batch, i) + # _plot_mesh(batch.pos, y, out, batch.batch, i) if converged: break x = out diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py index 13f3e45..3e69af9 100644 --- a/ThermalSolver/model/learnable_finite_difference.py +++ b/ThermalSolver/model/learnable_finite_difference.py @@ -17,11 +17,11 @@ class FiniteDifferenceStep(MessagePassing): spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), ) - self.update_net = nn.Sequential( - spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), - nn.GELU(), - spectral_norm(nn.Linear(hidden_dim, hidden_dim)), - ) + # self.update_net = nn.Sequential( + # spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), + # nn.GELU(), + # spectral_norm(nn.Linear(hidden_dim, hidden_dim)), + # ) self.out_net = nn.Sequential( spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), @@ -47,8 +47,9 @@ class FiniteDifferenceStep(MessagePassing): """ TODO: add docstring. """ - update_input = torch.cat([x, aggr_out], dim=-1) - return self.update_net(update_input) + # update_input = torch.cat([x, aggr_out], dim=-1) + # return self.update_net(update_input) + return aggr_out def aggregate(self, inputs, index, deg): """