Issue #95 fixed and improved plot_loss (#106)

* plotter.py plot_loss fixed Issue #95
* Improved the plot_loss to show title and labels.
This commit is contained in:
Bo van Hasselt
2023-06-28 15:25:24 +02:00
committed by GitHub
parent 67fb7fe891
commit b6adcffe07

View File

@@ -186,7 +186,7 @@ class Plotter:
else:
plt.show()
def plot_loss(self, pinn, label=None, log_scale=True):
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
"""
Plot the loss function values during traininig.
@@ -194,6 +194,8 @@ class Plotter:
:param str label: the label to use in the legend, defaults to None.
:param bool log_scale: If True, the y axis is in log scale. Default is
True.
:param str filename: the file name to save the plot. If None, the plot
is not saved. Default is None.
"""
if not label:
@@ -201,9 +203,20 @@ class Plotter:
epochs = list(pinn.history_loss.keys())
loss = np.array(list(pinn.history_loss.values()))
if loss.ndim != 1:
loss = loss[:, 0]
# if multiple outputs, sum the loss
if loss.ndim != 1:
loss = np.sum(loss, axis=1)
# plot loss
plt.plot(epochs, loss, label=label)
plt.legend()
if log_scale:
plt.yscale('log')
plt.title('Loss function')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# save plot
if filename:
plt.savefig(filename)