fix training_step
@@ -108,8 +108,7 @@ class GraphSolver(LightningModule):
def accumulate_gradients(self, losses, max_acc_iters):
    loss_ = torch.stack(losses, dim=0).mean()
    loss = 0.5 * loss_ / self.accumulation_iters
    self.manual_backward(loss / max_acc_iters, retain_graph=True)
    self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
    return loss_.item()

def _preprocess_batch(self, batch: Batch):
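Read against the -108,8 +108,7 header, the change appears to drop the extra loss = 0.5 * loss_ / self.accumulation_iters scaling and back-propagate the stacked mean loss directly, scaled only by max_acc_iters. A minimal sketch of the resulting helper, assuming PyTorch Lightning manual optimization; the hypothetical host class and the automatic_optimization flag are not shown in the diff:

import torch
from pytorch_lightning import LightningModule


class GraphSolverSketch(LightningModule):
    """Hypothetical minimal host class; the real GraphSolver carries more state."""

    def __init__(self):
        super().__init__()
        # manual_backward() requires manual optimization.
        self.automatic_optimization = False

    def accumulate_gradients(self, losses, max_acc_iters):
        # Mean of the per-iteration losses collected since the last flush.
        loss_ = torch.stack(losses, dim=0).mean()
        # Scale by the accumulation horizon so the accumulated gradient matches
        # one averaged step, and keep the graph for later backward passes.
        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
        return loss_.item()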
@@ -164,6 +163,7 @@ class GraphSolver(LightningModule):
converged = self._check_convergence(out, x)
if converged:
    if losses:
        loss = self.accumulate_gradients(losses, max_acc_iters)
        acc_it += 1
        acc_loss = acc_loss + loss
@@ -177,13 +177,29 @@ class GraphSolver(LightningModule):
acc_loss = acc_loss + loss

x = out
if i % self.accumulation_iters != 0:
    loss = self.loss(out, y)
    loss.backward()
    # self.manual_backward(loss)
    optim.step()
    optim.zero_grad()

self._log_loss(acc_loss / acc_it, batch, "train")
# self.log(
# "train/y_loss",
# loss,
# on_step=False,
# on_epoch=True,
# prog_bar=True,
# batch_size=int(batch.num_graphs),
# )
self.log(
    "train/accumulated_loss",
    (acc_loss / acc_it if acc_it > 0 else acc_loss),
    on_step=False,
    on_epoch=True,
    prog_bar=True,
    batch_size=int(batch.num_graphs),
)
self.log(
    "train/iterations",
    i + 1,
@@ -231,7 +247,7 @@ class GraphSolver(LightningModule):
if converged:
    break
x = out
loss = 0.5 * self.loss(out, y)
loss = self.loss(out, y)
self._log_loss(loss, batch, "val")
self.log(
    "val/iterations",
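Taken together, the hunks suggest a manual-optimization training_step that accumulates losses over the solver's fixed-point iterations, flushes them through accumulate_gradients, and logs the averaged loss. The sketch below is one plausible arrangement, not the repository's exact code: the loop bound self.max_iters, the forward call self(x), the _preprocess_batch unpacking, and the definition of max_acc_iters are assumptions; the accumulation, stepping, and logging calls mirror lines visible in the diff.

def training_step(self, batch, batch_idx):
    # Sketch under the assumptions named above; not the actual implementation.
    optim = self.optimizers()
    x, y = self._preprocess_batch(batch)          # assumed to return (input, target)
    max_acc_iters = self.accumulation_iters       # assumed definition of the horizon
    losses, acc_loss, acc_it = [], 0.0, 0

    for i in range(1, self.max_iters + 1):        # assumed iteration cap
        out = self(x)                             # assumed forward pass
        losses.append(self.loss(out, y))
        converged = self._check_convergence(out, x)

        # Flush the collected losses every accumulation_iters steps or on convergence.
        if converged or i % self.accumulation_iters == 0:
            if losses:
                loss = self.accumulate_gradients(losses, max_acc_iters)
                acc_it += 1
                acc_loss = acc_loss + loss
                losses = []
            optim.step()
            optim.zero_grad()
            if converged:
                break
        x = out

    self._log_loss(acc_loss / acc_it if acc_it > 0 else acc_loss, batch, "train")
    self.log(
        "train/accumulated_loss",
        (acc_loss / acc_it if acc_it > 0 else acc_loss),
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        batch_size=int(batch.num_graphs),
    )
    self.log("train/iterations", i, on_epoch=True, batch_size=int(batch.num_graphs))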