fix training_step

This commit is contained in:
Filippo Olivo
2025-11-14 17:05:48 +01:00
parent ea9cf7c57c
commit e1117d89c6

View File

@@ -108,8 +108,7 @@ class GraphSolver(LightningModule):
def accumulate_gradients(self, losses, max_acc_iters):
    """Average the accumulated losses and run one scaled backward pass.

    Fix: the diff extraction left both the pre-fix and post-fix backward
    calls in place, which would issue `manual_backward` twice and
    double-count gradients. Only the corrected call — backward on the
    mean loss ``loss_`` — is kept. The ``0.5 * loss_ /
    self.accumulation_iters`` local from the old call site is dead after
    the fix and has been removed.

    Args:
        losses: list of scalar loss tensors collected over the current
            accumulation window.
        max_acc_iters: number the loss is divided by before backward so
            gradient magnitudes are independent of the window length.

    Returns:
        float: the unscaled mean loss as a Python float, for logging.
    """
    # Stack the per-iteration scalars into a 1-D tensor and average.
    loss_ = torch.stack(losses, dim=0).mean()
    # retain_graph=True: presumably the shared graph receives further
    # backward calls later in the same training step — TODO confirm
    # against the training_step loop.
    self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
    return loss_.item()
def _preprocess_batch(self, batch: Batch):
@@ -164,6 +163,7 @@ class GraphSolver(LightningModule):
converged = self._check_convergence(out, x)
if converged:
if losses:
loss = self.accumulate_gradients(losses, max_acc_iters)
acc_it += 1
acc_loss = acc_loss + loss
@@ -177,13 +177,29 @@ class GraphSolver(LightningModule):
acc_loss = acc_loss + loss
x = out
if i % self.accumulation_iters != 0:
loss = self.loss(out, y)
loss.backward()
# self.manual_backward(loss)
optim.step()
optim.zero_grad()
self._log_loss(acc_loss / acc_it, batch, "train")
# self.log(
# "train/y_loss",
# loss,
# on_step=False,
# on_epoch=True,
# prog_bar=True,
# batch_size=int(batch.num_graphs),
# )
self.log(
"train/accumulated_loss",
(acc_loss / acc_it if acc_it > 0 else acc_loss),
on_step=False,
on_epoch=True,
prog_bar=True,
batch_size=int(batch.num_graphs),
)
self.log(
"train/iterations",
i + 1,
@@ -231,7 +247,7 @@ class GraphSolver(LightningModule):
if converged:
break
x = out
loss = 0.5 * self.loss(out, y)
loss = self.loss(out, y)
self._log_loss(loss, batch, "val")
self.log(
"val/iterations",