improve training_step
@@ -106,9 +106,9 @@ class GraphSolver(LightningModule):
             return True
         return False
 
-    def accumulate_gradients(self, losses, max_acc_iters):
+    def accumulate_gradients(self, losses):
         loss_ = torch.stack(losses, dim=0).mean()
-        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_, retain_graph=True)
         return loss_.item()
 
     def _preprocess_batch(self, batch: Batch):
@@ -124,7 +124,7 @@ class GraphSolver(LightningModule):
         edge_attr = edge_attr * c_ij
         return x, y, edge_index, edge_attr
 
-    def training_step(self, batch: Batch, batch_idx: int):
+    def training_step(self, batch: Batch, _):
         optim = self.optimizers()
         optim.zero_grad()
         x, y, edge_index, edge_attr = self._preprocess_batch(batch)
@@ -153,7 +153,7 @@ class GraphSolver(LightningModule):
                 self.accumulation_iters is not None
                 and (i + 1) % self.accumulation_iters == 0
             ):
-                loss = self.accumulate_gradients(losses, max_acc_iters)
+                loss = self.accumulate_gradients(losses)
                 losses = []
                 acc_it += 1
                 out = out.detach()
@@ -163,8 +163,7 @@ class GraphSolver(LightningModule):
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
-
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss
                 break
@@ -172,26 +171,19 @@ class GraphSolver(LightningModule):
             # Final accumulation if we are at the last iteration
             if i == self.current_iters - 1:
                 if losses:
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss
 
             x = out
 
         loss = self.loss(out, y)
 
-        # self.manual_backward(loss)
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad /= acc_it
         optim.step()
         optim.zero_grad()
-
-        # self.log(
-        #     "train/y_loss",
-        #     loss,
-        #     on_step=False,
-        #     on_epoch=True,
-        #     prog_bar=True,
-        #     batch_size=int(batch.num_graphs),
-        # )
         self.log(
             "train/accumulated_loss",
             (acc_loss / acc_it if acc_it > 0 else acc_loss),
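
For context on the change: the old code pre-scaled every accumulated loss by a fixed max_acc_iters inside the backward call, while the updated training_step backpropagates the unscaled mean loss and instead divides the accumulated gradients by acc_it, the number of backward passes actually performed, which is only known once the inner loop ends (it may break early on convergence). The following is a rough standalone sketch of that accumulation pattern, assuming a placeholder model, optimizer, data, and loop bounds rather than the repository's GraphSolver:

import torch
import torch.nn.functional as F

# Placeholder model/optimizer standing in for the solver's self.model / self.optimizers().
model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

def accumulate_gradients(losses):
    # Mirrors the updated accumulate_gradients: average the buffered losses
    # and backpropagate without dividing by a fixed max_acc_iters.
    loss_ = torch.stack(losses, dim=0).mean()
    loss_.backward()
    return loss_.item()

x = torch.randn(8, 4)
y = torch.randn(8, 4)

optim.zero_grad()
losses, acc_it, acc_loss = [], 0, 0.0
for i in range(6):  # inner solver iterations; a real solver may break early on convergence
    out = model(x)
    losses.append(F.mse_loss(out, y))
    if (i + 1) % 2 == 0:  # accumulation boundary
        acc_loss += accumulate_gradients(losses)
        losses = []
        acc_it += 1
    x = out.detach()  # feed the detached output back in, as training_step does with x = out

# Post-scale by the number of backward passes actually performed,
# the counterpart of `param.grad /= acc_it` in the diff above.
for param in model.parameters():
    if param.grad is not None:
        param.grad /= acc_it
optim.step()
optim.zero_grad()
print("mean accumulated loss:", acc_loss / acc_it)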