fix training_step

Filippo Olivo
2025-11-14 17:05:48 +01:00
parent ea9cf7c57c
commit e1117d89c6


@@ -108,8 +108,7 @@ class GraphSolver(LightningModule):
     def accumulate_gradients(self, losses, max_acc_iters):
         loss_ = torch.stack(losses, dim=0).mean()
-        loss = 0.5 * loss_ / self.accumulation_iters
-        self.manual_backward(loss / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
         return loss_.item()

     def _preprocess_batch(self, batch: Batch):
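The dropped line pre-scaled the stacked loss by 0.5 / self.accumulation_iters and the backward call then divided it by max_acc_iters again, so gradients were scaled twice; after the fix the mean loss is divided exactly once by max_acc_iters, and the unscaled value is returned for logging. A minimal, self-contained sketch of that pattern under Lightning manual optimization (assuming the Lightning 2.x import path); everything outside the accumulate_gradients body shown in the diff, such as the example model, criterion, and training_step wiring, is illustrative:

import torch
from lightning.pytorch import LightningModule


class AccumulationSketch(LightningModule):
    """Illustrative module: manual gradient accumulation with a single scale factor."""

    def __init__(self, accumulation_iters: int = 4):
        super().__init__()
        self.automatic_optimization = False        # manual_backward requires this
        self.model = torch.nn.Linear(8, 1)
        self.accumulation_iters = accumulation_iters

    def accumulate_gradients(self, losses, max_acc_iters):
        # Average the collected losses, then scale once by the window size so the
        # summed gradients match a single large-batch backward pass.
        loss_ = torch.stack(losses, dim=0).mean()
        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
        return loss_.item()                        # return the unscaled value for logging

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        losses = [torch.nn.functional.mse_loss(self.model(x), y)]
        self.accumulate_gradients(losses, self.accumulation_iters)
        # Step the optimizer only once per full accumulation window.
        if (batch_idx + 1) % self.accumulation_iters == 0:
            opt.step()
            opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)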
@@ -164,6 +163,7 @@ class GraphSolver(LightningModule):
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
                     loss = self.accumulate_gradients(losses, max_acc_iters)
                     acc_it += 1
                     acc_loss = acc_loss + loss
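The `if losses:` guard around the flush matters because `accumulate_gradients` starts with `torch.stack(losses, dim=0)`: if the solver converges before any loss has been recorded for the current accumulation window, stacking an empty list raises a RuntimeError. A tiny illustration of the guard (the variable names are only for the example):

import torch

losses = []                      # e.g. convergence hit before any loss was appended
if losses:
    flushed = torch.stack(losses, dim=0).mean()   # safe: list is non-empty
else:
    flushed = None               # skip the flush and the backward pass entirely

# Without the guard, torch.stack([], dim=0) raises a RuntimeError,
# since stack expects a non-empty sequence of tensors.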
@@ -177,13 +177,29 @@ class GraphSolver(LightningModule):
                 acc_loss = acc_loss + loss
             x = out
+        if i % self.accumulation_iters != 0:
             loss = self.loss(out, y)
+            loss.backward()
+            # self.manual_backward(loss)
         optim.step()
         optim.zero_grad()
-        self._log_loss(acc_loss / acc_it, batch, "train")
+        # self.log(
+        #     "train/y_loss",
+        #     loss,
+        #     on_step=False,
+        #     on_epoch=True,
+        #     prog_bar=True,
+        #     batch_size=int(batch.num_graphs),
+        # )
+        self.log(
+            "train/accumulated_loss",
+            (acc_loss / acc_it if acc_it > 0 else acc_loss),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=int(batch.num_graphs),
+        )
         self.log(
             "train/iterations",
             i + 1,
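Two details in this hunk carry the actual fix. When the loop exits on a partial accumulation window (i % self.accumulation_iters != 0), one extra backward pass is taken on the last loss before the single optimizer step, so those iterations are not silently dropped. And the logged average is divided by acc_it only when at least one flush happened, which avoids a division by zero for batches that converge immediately. A small sketch of that logging guard; guarded_average is a hypothetical helper, not part of the module:

import torch


def guarded_average(acc_loss: torch.Tensor, acc_it: int) -> torch.Tensor:
    """Mean accumulated loss, falling back to the raw value when no flush happened."""
    return acc_loss / acc_it if acc_it > 0 else acc_loss


print(guarded_average(torch.tensor(1.2), 3))   # averaged over three flushes -> ~0.4
print(guarded_average(torch.tensor(0.0), 0))   # acc_it == 0: no division by zero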
@@ -231,7 +247,7 @@ class GraphSolver(LightningModule):
             if converged:
                 break
             x = out
-        loss = 0.5 * self.loss(out, y)
+        loss = self.loss(out, y)
         self._log_loss(loss, batch, "val")
         self.log(
             "val/iterations",