* plotter.py plot_loss fixed Issue #95 * Improved the plot_loss to show title and labels.
This commit is contained in:
@@ -186,7 +186,7 @@ class Plotter:
|
|||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def plot_loss(self, pinn, label=None, log_scale=True):
|
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
|
||||||
"""
|
"""
|
||||||
Plot the loss function values during traininig.
|
Plot the loss function values during traininig.
|
||||||
|
|
||||||
@@ -194,6 +194,8 @@ class Plotter:
|
|||||||
:param str label: the label to use in the legend, defaults to None.
|
:param str label: the label to use in the legend, defaults to None.
|
||||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||||
True.
|
True.
|
||||||
|
:param str filename: the file name to save the plot. If None, the plot
|
||||||
|
is not saved. Default is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not label:
|
if not label:
|
||||||
@@ -201,9 +203,20 @@ class Plotter:
|
|||||||
|
|
||||||
epochs = list(pinn.history_loss.keys())
|
epochs = list(pinn.history_loss.keys())
|
||||||
loss = np.array(list(pinn.history_loss.values()))
|
loss = np.array(list(pinn.history_loss.values()))
|
||||||
if loss.ndim != 1:
|
|
||||||
loss = loss[:, 0]
|
|
||||||
|
|
||||||
|
# if multiple outputs, sum the loss
|
||||||
|
if loss.ndim != 1:
|
||||||
|
loss = np.sum(loss, axis=1)
|
||||||
|
|
||||||
|
# plot loss
|
||||||
plt.plot(epochs, loss, label=label)
|
plt.plot(epochs, loss, label=label)
|
||||||
|
plt.legend()
|
||||||
if log_scale:
|
if log_scale:
|
||||||
plt.yscale('log')
|
plt.yscale('log')
|
||||||
|
plt.title('Loss function')
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
|
||||||
|
# save plot
|
||||||
|
if filename:
|
||||||
|
plt.savefig(filename)
|
||||||
|
|||||||
Reference in New Issue
Block a user