diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py
index 5829db8..4034e36 100644
--- a/ThermalSolver/graph_module.py
+++ b/ThermalSolver/graph_module.py
@@ -106,9 +106,9 @@ class GraphSolver(LightningModule):
             return True
         return False

-    def accumulate_gradients(self, losses, max_acc_iters):
+    def accumulate_gradients(self, losses):
         loss_ = torch.stack(losses, dim=0).mean()
-        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_, retain_graph=True)
         return loss_.item()

     def _preprocess_batch(self, batch: Batch):
@@ -124,7 +124,7 @@ class GraphSolver(LightningModule):
         edge_attr = edge_attr * c_ij
         return x, y, edge_index, edge_attr

-    def training_step(self, batch: Batch, batch_idx: int):
+    def training_step(self, batch: Batch, _):
         optim = self.optimizers()
         optim.zero_grad()
         x, y, edge_index, edge_attr = self._preprocess_batch(batch)
@@ -153,7 +153,7 @@ class GraphSolver(LightningModule):
                 self.accumulation_iters is not None
                 and (i + 1) % self.accumulation_iters == 0
             ):
-                loss = self.accumulate_gradients(losses, max_acc_iters)
+                loss = self.accumulate_gradients(losses)
                 losses = []
                 acc_it += 1
                 out = out.detach()
@@ -163,8 +163,7 @@ class GraphSolver(LightningModule):
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
-
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss
                 break
@@ -172,26 +171,19 @@ class GraphSolver(LightningModule):
             # Final accumulation if we are at the last iteration
             if i == self.current_iters - 1:
                 if losses:
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss

             x = out

         loss = self.loss(out, y)
-
-        # self.manual_backward(loss)
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad /= acc_it
         optim.step()
         optim.zero_grad()
-        # self.log(
-        #     "train/y_loss",
-        #     loss,
-        #     on_step=False,
-        #     on_epoch=True,
-        #     prog_bar=True,
-        #     batch_size=int(batch.num_graphs),
-        # )

         self.log(
             "train/accumulated_loss",
             (acc_loss / acc_it if acc_it > 0 else acc_loss),
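
For reference, here is a minimal plain-PyTorch sketch of the accumulation scheme this diff moves to: each chunk's mean loss is backpropagated unscaled, the per-chunk gradients sum into `param.grad`, and that sum is divided by the actual number of backward calls (`acc_it`) just before the optimizer step. The toy model, data, and the names `accum_every` and `n_inner` are illustrative stand-ins, not code from this repo.

```python
import torch
from torch import nn

# Toy stand-ins (illustrative only): a tiny model and a random batch.
model = nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(8, 4)
y = torch.randn(8, 4)

accum_every = 2   # analogous to self.accumulation_iters
n_inner = 6       # analogous to self.current_iters
losses, acc_it = [], 0

optim.zero_grad()
out = x
for i in range(n_inner):
    out = model(out)
    losses.append(nn.functional.mse_loss(out, y))
    if (i + 1) % accum_every == 0 or i == n_inner - 1:
        # Backward on the unscaled mean loss; gradients from each
        # accumulation chunk simply sum into param.grad.
        # retain_graph mirrors the diff; with the detach below it is
        # not strictly required in this sketch.
        torch.stack(losses).mean().backward(retain_graph=True)
        losses = []
        acc_it += 1
        out = out.detach()  # cut the graph between chunks

# Average the summed gradients over the number of backward calls,
# replacing the old "divide each loss by max_acc_iters up front" scheme.
for param in model.parameters():
    if param.grad is not None:
        param.grad /= acc_it

optim.step()
optim.zero_grad()
```

Deferring the division until after the loop is what makes the early-exit path correct: the old `loss_ / max_acc_iters` pre-scaling assumed a fixed number of accumulation steps, which under-weights the gradients whenever convergence breaks out of the loop before all `max_acc_iters` chunks have run.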