diff --git a/pina/plotter.py b/pina/plotter.py index 39102b9..b509e2d 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -186,7 +186,7 @@ class Plotter: else: plt.show() - def plot_loss(self, pinn, label=None, log_scale=True): + def plot_loss(self, pinn, label=None, log_scale=True, filename=None): """ Plot the loss function values during traininig. @@ -194,6 +194,8 @@ class Plotter: :param str label: the label to use in the legend, defaults to None. :param bool log_scale: If True, the y axis is in log scale. Default is True. + :param str filename: the file name to save the plot. If None, the plot + is not saved. Default is None. """ if not label: @@ -201,9 +203,20 @@ class Plotter: epochs = list(pinn.history_loss.keys()) loss = np.array(list(pinn.history_loss.values())) - if loss.ndim != 1: - loss = loss[:, 0] + # if multiple outputs, sum the loss + if loss.ndim != 1: + loss = np.sum(loss, axis=1) + + # plot loss plt.plot(epochs, loss, label=label) + plt.legend() if log_scale: plt.yscale('log') + plt.title('Loss function') + plt.xlabel('Epochs') + plt.ylabel('Loss') + + # save plot + if filename: + plt.savefig(filename)