improve training_step
@@ -106,9 +106,9 @@ class GraphSolver(LightningModule):
             return True
         return False

-    def accumulate_gradients(self, losses, max_acc_iters):
+    def accumulate_gradients(self, losses):
         loss_ = torch.stack(losses, dim=0).mean()
-        self.manual_backward(loss_ / max_acc_iters, retain_graph=True)
+        self.manual_backward(loss_, retain_graph=True)
         return loss_.item()

     def _preprocess_batch(self, batch: Batch):
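A note on the first hunk: with this change `accumulate_gradients` no longer divides each chunk's mean loss by `max_acc_iters` before `manual_backward`; the later hunks instead divide the accumulated parameter gradients by the number of chunks that actually ran (`acc_it`). When every chunk runs, the two scalings are equivalent. A toy PyTorch check of that equivalence (illustrative only, not code from this repository):

import torch

# Toy check: dividing each partial loss by N before backward gives the same
# gradient as backpropagating the losses unscaled and dividing the accumulated
# gradient by N afterwards, which is what this commit switches to.
w = torch.tensor(2.0, requires_grad=True)
partial_losses = [w * c for c in (1.0, 3.0)]
N = len(partial_losses)

# Old scheme: scale each loss before backward.
for loss in partial_losses:
    (loss / N).backward(retain_graph=True)
old_grad = w.grad.clone()

# New scheme: backward unscaled, then normalize the accumulated gradient.
w.grad = None
for loss in partial_losses:
    loss.backward(retain_graph=True)
w.grad /= N

print(torch.allclose(old_grad, w.grad))  # True: both give dL/dw = 2.0

The difference only shows up when the loop is cut short (early convergence), in which case normalizing by the actual `acc_it` is the better-scaled choice.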
@@ -124,7 +124,7 @@ class GraphSolver(LightningModule):
         edge_attr = edge_attr * c_ij
         return x, y, edge_index, edge_attr

-    def training_step(self, batch: Batch, batch_idx: int):
+    def training_step(self, batch: Batch, _):
         optim = self.optimizers()
         optim.zero_grad()
         x, y, edge_index, edge_attr = self._preprocess_batch(batch)
@@ -153,7 +153,7 @@ class GraphSolver(LightningModule):
                 self.accumulation_iters is not None
                 and (i + 1) % self.accumulation_iters == 0
             ):
-                loss = self.accumulate_gradients(losses, max_acc_iters)
+                loss = self.accumulate_gradients(losses)
                 losses = []
                 acc_it += 1
             out = out.detach()
@@ -163,8 +163,7 @@ class GraphSolver(LightningModule):
             converged = self._check_convergence(out, x)
             if converged:
                 if losses:
-
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss
                 break
@@ -172,26 +171,19 @@ class GraphSolver(LightningModule):
             # Final accumulation if we are at the last iteration
             if i == self.current_iters - 1:
                 if losses:
-                    loss = self.accumulate_gradients(losses, max_acc_iters)
+                    loss = self.accumulate_gradients(losses)
                     acc_it += 1
                     acc_loss = acc_loss + loss

             x = out

         loss = self.loss(out, y)
-        # self.manual_backward(loss)
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad /= acc_it
         optim.step()
         optim.zero_grad()

-        # self.log(
-        #     "train/y_loss",
-        #     loss,
-        #     on_step=False,
-        #     on_epoch=True,
-        #     prog_bar=True,
-        #     batch_size=int(batch.num_graphs),
-        # )
         self.log(
             "train/accumulated_loss",
             (acc_loss / acc_it if acc_it > 0 else acc_loss),
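Taken together, the commit switches from pre-scaling each partial loss to post-normalizing the accumulated gradients before the optimizer step. A minimal self-contained sketch of the resulting manual-optimization loop, assuming PyTorch Lightning's `manual_backward`/`optimizers` API; the model, loss, and iteration counts are placeholders, not the repository's `GraphSolver`:

import torch
from torch import nn
from pytorch_lightning import LightningModule


class AccumulatingSolver(LightningModule):
    """Illustrative module only: mirrors the accumulation scheme after this commit."""

    def __init__(self, model: nn.Module, accumulation_iters: int = 4, current_iters: int = 10):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()
        self.accumulation_iters = accumulation_iters
        self.current_iters = current_iters
        self.automatic_optimization = False  # required for manual_backward()/optimizers()

    def accumulate_gradients(self, losses):
        # Backpropagate the mean of the buffered per-iteration losses, unscaled.
        loss_ = torch.stack(losses, dim=0).mean()
        self.manual_backward(loss_, retain_graph=True)
        return loss_.item()

    def training_step(self, batch, _):
        optim = self.optimizers()
        optim.zero_grad()
        x, y = batch  # placeholder unpacking; the real code preprocesses a PyG Batch

        losses, acc_it, acc_loss = [], 0, 0.0
        for i in range(self.current_iters):
            out = self.model(x)
            losses.append(self.loss(out, y))
            if (i + 1) % self.accumulation_iters == 0:
                acc_loss += self.accumulate_gradients(losses)
                losses = []
                acc_it += 1
            # Final accumulation if we are at the last iteration
            if i == self.current_iters - 1 and losses:
                acc_loss += self.accumulate_gradients(losses)
                acc_it += 1
            x = out.detach()

        # Normalize the accumulated gradients by the number of backward passes,
        # replacing the old per-loss division by max_acc_iters.
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad /= max(acc_it, 1)

        optim.step()
        optim.zero_grad()
        self.log("train/accumulated_loss", acc_loss / acc_it if acc_it > 0 else acc_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

The `max(acc_it, 1)` guard is an extra safety net for the degenerate case where no backward pass ran; the repository code divides by `acc_it` directly and guards only the logged value.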