fix : gpu loss plotting
This commit is contained in:
@@ -256,7 +256,7 @@ class Plotter:
|
|||||||
)
|
)
|
||||||
loss = trainer_metrics[metric]
|
loss = trainer_metrics[metric]
|
||||||
epochs = range(len(loss))
|
epochs = range(len(loss))
|
||||||
plt.plot(epochs, loss, **kwargs)
|
plt.plot(epochs, loss.cpu(), **kwargs)
|
||||||
|
|
||||||
# plotting
|
# plotting
|
||||||
plt.xlabel('epoch')
|
plt.xlabel('epoch')
|
||||||
|
|||||||
Reference in New Issue
Block a user