minor fix, add few tests (#38)

This commit is contained in:
Nicola Demo
2022-11-29 12:42:01 +01:00
committed by GitHub
parent a1e947fede
commit 936f5e1043
4 changed files with 52 additions and 12 deletions

View File

@@ -131,3 +131,22 @@ class Plotter:
plt.savefig(filename)
else:
plt.show()
def plot_loss(self, pinn, label=None, log_scale=True):
"""
Plot the loss trend
TODO
"""
if not label:
label = str(pinn)
epochs = list(pinn.history_loss.keys())
loss = np.array(list(pinn.history_loss.values()))
if loss.ndim != 1:
loss = loss[:, 0]
plt.plot(epochs, loss, label=label)
if log_scale:
plt.yscale('log')