fix training_step

@@ -108,8 +108,7 @@ class GraphSolver(LightningModule):
 
     def accumulate_gradients(self, losses, max_acc_iters):
         loss_ = torch.stack(losses, dim=0).mean()
-        loss = 0.5 * loss_ / self.accumulation_iters
-        self.manual_backward(loss / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
         return loss_.item()
 
     def _preprocess_batch(self, batch: Batch):
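
Note: the helper no longer pre-scales the mean loss by 0.5 / self.accumulation_iters before backpropagating; it now calls manual_backward on loss_ / max_acc_iters directly, so each accumulation step contributes exactly 1/max_acc_iters of its mean loss to the summed gradients. A minimal sketch of the method as it reads after this commit; the class scaffolding and the self.automatic_optimization = False flag are assumptions, not part of the diff:

import torch
from pytorch_lightning import LightningModule


class GraphSolver(LightningModule):
    def __init__(self):
        super().__init__()
        # Assumed: manual optimization is required for self.manual_backward().
        self.automatic_optimization = False

    def accumulate_gradients(self, losses, max_acc_iters):
        # Mean of the per-iteration losses collected by the caller.
        loss_ = torch.stack(losses, dim=0).mean()
        # Scale the backward pass so that max_acc_iters accumulation steps
        # together produce gradients comparable to one full-batch step.
        # retain_graph=True keeps the graph alive for later solver iterations.
        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
        return loss_.item()
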
@@ -164,6 +163,7 @@ class GraphSolver(LightningModule):
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
+
                     loss = self.accumulate_gradients(losses, max_acc_iters)
                     acc_it += 1
                     acc_loss = acc_loss + loss
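
Note: this hunk shows how the helper is driven from the solver loop: losses are collected until _check_convergence fires, then accumulate_gradients is called once, and acc_it / acc_loss count the accumulation steps and their summed loss. A hedged sketch of that call-site pattern, written as a hypothetical free function; the loop bound, forward pass, loss collection, and break are assumptions, only the lines visible in the hunk come from the commit:

def solver_loop_sketch(model, loss_fn, check_convergence, accumulate_gradients,
                       x, y, max_iters, max_acc_iters):
    # Hypothetical standalone rendering; the real code lives inside
    # GraphSolver.training_step.
    losses, acc_it, acc_loss = [], 0, 0.0
    for i in range(max_iters):                 # assumed iteration limit
        out = model(x)                         # assumed forward pass
        losses.append(loss_fn(out, y))         # assumed per-iteration loss
        converged = check_convergence(out, x)
        if converged:
            if losses:
                loss = accumulate_gradients(losses, max_acc_iters)
                acc_it += 1
                acc_loss = acc_loss + loss
            break                              # assumed: stop once converged
        x = out
    return acc_loss, acc_it
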
@@ -177,13 +177,29 @@ class GraphSolver(LightningModule):
                     acc_loss = acc_loss + loss
 
             x = out
-        if i % self.accumulation_iters != 0:
         loss = self.loss(out, y)
-        loss.backward()
+        # self.manual_backward(loss)
         optim.step()
         optim.zero_grad()
 
-        self._log_loss(acc_loss / acc_it, batch, "train")
+        # self.log(
+        #     "train/y_loss",
+        #     loss,
+        #     on_step=False,
+        #     on_epoch=True,
+        #     prog_bar=True,
+        #     batch_size=int(batch.num_graphs),
+        # )
+        self.log(
+            "train/accumulated_loss",
+            (acc_loss / acc_it if acc_it > 0 else acc_loss),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=int(batch.num_graphs),
+        )
+
         self.log(
             "train/iterations",
             i + 1,
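
Note: after the loop, the extra loss.backward() on the final output is dropped (it survives only as a commented-out manual_backward), so optim.step() now sees only the gradients accumulated inside the loop. The per-batch log also moves from self._log_loss(acc_loss / acc_it, batch, "train") to an explicit self.log call whose value is guarded, because acc_it stays 0 whenever convergence never triggered an accumulation step. A tiny illustration of why the guard matters, with dummy values that are not part of the commit:

acc_loss, acc_it = 0.0, 0            # no accumulation step ran this batch
# Unguarded division, as in the old _log_loss call: acc_loss / acc_it
# would raise ZeroDivisionError here.
logged = acc_loss / acc_it if acc_it > 0 else acc_loss
print(logged)                        # 0.0, logged without raising
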
@@ -231,7 +247,7 @@ class GraphSolver(LightningModule):
             if converged:
                 break
             x = out
-        loss = 0.5 * self.loss(out, y)
+        loss = self.loss(out, y)
         self._log_loss(loss, batch, "val")
         self.log(
             "val/iterations",
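
Note: the validation loss is now logged unscaled, matching the unscaled mean used on the training side after this commit, so the train and val metrics stay on the same scale. A dummy-number illustration; the values are made up:

import torch

raw = torch.tensor(0.8)              # stand-in for self.loss(out, y)
old_val_logged = 0.5 * raw           # pre-commit: 0.4, half the training scale
new_val_logged = raw                 # post-commit: 0.8, same scale as training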