fix : gpu loss plotting

This commit is contained in:
mathayay
2023-11-29 16:31:28 +01:00
committed by Nicola Demo
parent 42cf7fcaf1
commit 4639374961

View File

@@ -256,7 +256,7 @@ class Plotter:
)
loss = trainer_metrics[metric]
epochs = range(len(loss))
plt.plot(epochs, loss, **kwargs)
plt.plot(epochs, loss.cpu(), **kwargs)
# plotting
plt.xlabel('epoch')