fix : gpu loss plotting

This commit is contained in:
mathayay
2023-11-29 16:31:28 +01:00
committed by Nicola Demo
parent 42cf7fcaf1
commit 4639374961

View File

@@ -256,7 +256,7 @@ class Plotter:
) )
loss = trainer_metrics[metric] loss = trainer_metrics[metric]
epochs = range(len(loss)) epochs = range(len(loss))
plt.plot(epochs, loss, **kwargs) plt.plot(epochs, loss.cpu(), **kwargs)
# plotting # plotting
plt.xlabel('epoch') plt.xlabel('epoch')