minor fix, add few tests (#38)
This commit is contained in:
@@ -131,3 +131,22 @@ class Plotter:
|
||||
plt.savefig(filename)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True):
|
||||
"""
|
||||
Plot the loss trend
|
||||
|
||||
TODO
|
||||
"""
|
||||
|
||||
if not label:
|
||||
label = str(pinn)
|
||||
|
||||
epochs = list(pinn.history_loss.keys())
|
||||
loss = np.array(list(pinn.history_loss.values()))
|
||||
if loss.ndim != 1:
|
||||
loss = loss[:, 0]
|
||||
|
||||
plt.plot(epochs, loss, label=label)
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
|
||||
Reference in New Issue
Block a user