* plotter.py plot_loss fixed Issue #95 * Improved the plot_loss to show title and labels.
This commit is contained in:
@@ -186,7 +186,7 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True):
|
||||
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
@@ -194,6 +194,8 @@ class Plotter:
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is not saved. Default is None.
|
||||
"""
|
||||
|
||||
if not label:
|
||||
@@ -201,9 +203,20 @@ class Plotter:
|
||||
|
||||
epochs = list(pinn.history_loss.keys())
|
||||
loss = np.array(list(pinn.history_loss.values()))
|
||||
if loss.ndim != 1:
|
||||
loss = loss[:, 0]
|
||||
|
||||
# if multiple outputs, sum the loss
|
||||
if loss.ndim != 1:
|
||||
loss = np.sum(loss, axis=1)
|
||||
|
||||
# plot loss
|
||||
plt.plot(epochs, loss, label=label)
|
||||
plt.legend()
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
plt.title('Loss function')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
|
||||
# save plot
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
|
||||
Reference in New Issue
Block a user