improve training_step

Filippo Olivo
2025-11-17 15:23:46 +01:00
parent 94ad6ff160
commit 1c7b593762


@@ -106,9 +106,9 @@ class GraphSolver(LightningModule):
return True
return False
def accumulate_gradients(self, losses, max_acc_iters):
def accumulate_gradients(self, losses):
loss_ = torch.stack(losses, dim=0).mean()
self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
self.manual_backward(loss_, retain_graph=True)
return loss_.item()
def _preprocess_batch(self, batch: Batch):
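Dropping max_acc_iters means accumulate_gradients no longer scales every backward pass; the averaging now happens once, by dividing the accumulated gradients by the realized step count later in training_step. A minimal standalone sketch of that pattern in plain PyTorch (toy model, not the GraphSolver):

import torch

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

acc_it = 0
for chunk_x, chunk_y in zip(x.split(2), y.split(2)):
    loss = torch.nn.functional.mse_loss(model(chunk_x), chunk_y)
    loss.backward()              # gradients from each chunk add up in .grad
    acc_it += 1

# One division by the number of steps actually taken; this equals scaling each
# backward by 1/acc_it, and stays correct when the loop stops early.
for p in model.parameters():
    if p.grad is not None:
        p.grad /= acc_it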
@@ -124,7 +124,7 @@ class GraphSolver(LightningModule):
edge_attr = edge_attr * c_ij
return x, y, edge_index, edge_attr
def training_step(self, batch: Batch, batch_idx: int):
def training_step(self, batch: Batch, _):
optim = self.optimizers()
optim.zero_grad()
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
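Renaming batch_idx to _ only marks the argument as unused; the calls around it follow PyTorch Lightning's manual-optimization contract, since the solver drives zero_grad, backward and step itself. A generic sketch of that contract (illustrative module, not the actual GraphSolver):

import torch
from lightning import LightningModule  # pytorch_lightning.LightningModule in older versions

class ManualOptSketch(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)
        self.automatic_optimization = False  # required when calling optim.step() by hand

    def training_step(self, batch, _):       # batch_idx unused, hence "_"
        optim = self.optimizers()
        optim.zero_grad()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.manual_backward(loss)           # keeps precision/strategy plugins working
        optim.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)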
@@ -153,7 +153,7 @@ class GraphSolver(LightningModule):
self.accumulation_iters is not None
and (i + 1) % self.accumulation_iters == 0
):
loss = self.accumulate_gradients(losses, max_acc_iters)
loss = self.accumulate_gradients(losses)
losses = []
acc_it += 1
out = out.detach()
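Every accumulation_iters steps the pending losses are averaged, backpropagated and cleared, and out is detached so the autograd graph does not keep growing across the whole rollout. The same pattern in a toy recurrent loop (this toy detaches right after the flush, so it does not need the retain_graph=True that the solver passes to manual_backward):

import torch

cell = torch.nn.Linear(4, 4)
state, target = torch.zeros(1, 4), torch.zeros(1, 4)
losses, acc_it, k = [], 0, 4

for i in range(16):
    state = torch.tanh(cell(state))
    losses.append(torch.nn.functional.mse_loss(state, target))
    if (i + 1) % k == 0:
        torch.stack(losses).mean().backward()  # gradients keep accumulating across flushes
        losses = []
        acc_it += 1
        state = state.detach()                 # cut the graph between accumulation windows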
@@ -163,8 +163,7 @@ class GraphSolver(LightningModule):
converged = self._check_convergence(out, x)
if converged:
if losses:
loss = self.accumulate_gradients(losses, max_acc_iters)
loss = self.accumulate_gradients(losses)
acc_it += 1
acc_loss = acc_loss + loss
break
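When the iteration converges, any losses that have not yet been flushed get one last accumulate_gradients call before the loop breaks, so no computed loss is silently dropped. The convergence test itself lives outside this diff; a hypothetical fixed-point style check with the same (out, x) signature could look like:

import torch

def check_convergence(out: torch.Tensor, x: torch.Tensor, tol: float = 1e-5) -> bool:
    # purely illustrative criterion: relative change between the new state and the previous one
    denom = x.norm().clamp_min(1e-12)
    return bool((out - x).norm() / denom < tol)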
@@ -172,26 +171,19 @@ class GraphSolver(LightningModule):
# Final accumulation if we are at the last iteration
if i == self.current_iters - 1:
if losses:
loss = self.accumulate_gradients(losses, max_acc_iters)
loss = self.accumulate_gradients(losses)
acc_it += 1
acc_loss = acc_loss + loss
x = out
loss = self.loss(out, y)
# self.manual_backward(loss)
for param in self.model.parameters():
if param.grad is not None:
param.grad /= acc_it
optim.step()
optim.zero_grad()
# self.log(
# "train/y_loss",
# loss,
# on_step=False,
# on_epoch=True,
# prog_bar=True,
# batch_size=int(batch.num_graphs),
# )
self.log(
"train/accumulated_loss",
(acc_loss / acc_it if acc_it > 0 else acc_loss),
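Because the per-call division by max_acc_iters is gone, the gradients are normalized once by acc_it, the number of accumulation steps that actually ran, and the logged value is the corresponding average of the flushed losses. A small sketch of why dividing by the realized count is preferable when convergence stops the loop early (toy numbers, not from the project):

import torch

w = torch.zeros(3, requires_grad=True)
chunks = [torch.ones(3), 2 * torch.ones(3)]  # only 2 chunks run although max_acc_iters was, say, 4
acc_it = 0

for c in chunks:
    (w * c).sum().backward()   # accumulated grad becomes [3., 3., 3.]
    acc_it += 1

w.grad /= acc_it               # mean over the steps actually taken: [1.5, 1.5, 1.5]
# scaling each backward by 1/max_acc_iters would have left [0.75, 0.75, 0.75] instead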