From e1117d89c651624644caba00a70aec3e55827da1 Mon Sep 17 00:00:00 2001
From: Filippo Olivo
Date: Fri, 14 Nov 2025 17:05:48 +0100
Subject: [PATCH] fix training_step

---
 ThermalSolver/graph_module.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py
index 35aec81..5829db8 100644
--- a/ThermalSolver/graph_module.py
+++ b/ThermalSolver/graph_module.py
@@ -108,8 +108,7 @@ class GraphSolver(LightningModule):
 
     def accumulate_gradients(self, losses, max_acc_iters):
         loss_ = torch.stack(losses, dim=0).mean()
-        loss = 0.5 * loss_ / self.accumulation_iters
-        self.manual_backward(loss / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
         return loss_.item()
 
     def _preprocess_batch(self, batch: Batch):
@@ -164,6 +163,7 @@ class GraphSolver(LightningModule):
 
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
+                    loss = self.accumulate_gradients(losses, max_acc_iters)
                     acc_it += 1
                     acc_loss = acc_loss + loss
@@ -177,13 +177,29 @@ class GraphSolver(LightningModule):
                 acc_loss = acc_loss + loss
             x = out
 
-        if i % self.accumulation_iters != 0:
-            loss = self.loss(out, y)
-            loss.backward()
+
+        loss = self.loss(out, y)
+
+        # self.manual_backward(loss)
         optim.step()
         optim.zero_grad()
-        self._log_loss(acc_loss / acc_it, batch, "train")
+        # self.log(
+        #     "train/y_loss",
+        #     loss,
+        #     on_step=False,
+        #     on_epoch=True,
+        #     prog_bar=True,
+        #     batch_size=int(batch.num_graphs),
+        # )
+        self.log(
+            "train/accumulated_loss",
+            (acc_loss / acc_it if acc_it > 0 else acc_loss),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=int(batch.num_graphs),
+        )
 
         self.log(
             "train/iterations",
             i + 1,
@@ -231,7 +247,7 @@ class GraphSolver(LightningModule):
             if converged:
                 break
             x = out
-        loss = 0.5 * self.loss(out, y)
+        loss = self.loss(out, y)
         self._log_loss(loss, batch, "val")
         self.log(
             "val/iterations",
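
Note (illustrative sketch, not part of the patch): the code below shows one plausible way the revised accumulate_gradients path can be driven from a manual-optimization training_step. Only the two methods touched by the hunks are taken from the patch; the import path, the constructor arguments (net, max_iters, max_acc_iters, tol), the batch layout, the per-iteration loss target, and the convergence criterion are assumptions and may differ from the real GraphSolver.

import torch
from lightning.pytorch import LightningModule  # import path assumed; the project may use pytorch_lightning


class GraphSolverSketch(LightningModule):
    def __init__(self, net, max_iters=50, max_acc_iters=8, tol=1e-4):
        super().__init__()
        self.net = net
        self.max_iters = max_iters
        self.max_acc_iters = max_acc_iters
        self.tol = tol
        self.loss = torch.nn.MSELoss()
        # Gradients are accumulated by hand, so Lightning's automatic
        # optimization must be disabled.
        self.automatic_optimization = False

    def accumulate_gradients(self, losses, max_acc_iters):
        # Average the buffered per-iteration losses and scale the backward
        # pass by 1 / max_acc_iters (as in the patched method above).
        loss_ = torch.stack(losses, dim=0).mean()
        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
        return loss_.item()

    def _check_convergence(self, out, x):
        # Hypothetical fixed-point criterion: stop when the update is small.
        return bool((out - x).norm() < self.tol)

    def training_step(self, batch, batch_idx):
        optim = self.optimizers()
        x, y = batch  # assumed batch layout; the real module works on PyG graphs
        losses, acc_loss, acc_it = [], 0.0, 0

        for i in range(self.max_iters):
            out = self.net(x)
            losses.append(self.loss(out, y))  # per-iteration loss target is an assumption
            converged = self._check_convergence(out, x)

            # Flush the loss buffer when it is full or the iteration converged.
            if len(losses) == self.max_acc_iters or converged:
                if losses:
                    acc_loss += self.accumulate_gradients(losses, self.max_acc_iters)
                    acc_it += 1
                    losses = []
            if converged:
                break
            x = out

        # The final target loss is computed for monitoring; the optimizer step
        # is driven by the gradients accumulated above.
        loss = self.loss(out, y)
        optim.step()
        optim.zero_grad()

        self.log(
            "train/accumulated_loss",
            acc_loss / acc_it if acc_it > 0 else acc_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log("train/iterations", float(i + 1), on_step=False, on_epoch=True)
        return loss

Scaling each flushed buffer by 1 / max_acc_iters before manual_backward makes the gradients summed across flushes behave like an average rather than a raw sum, which matches what the patched accumulate_gradients does.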