diff --git a/pina/plotter.py b/pina/plotter.py index 6fb3af1..0d64afc 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import torch from pina.callbacks import MetricTracker - +from pina.utils import check_consistency from pina import LabelTensor @@ -189,15 +189,19 @@ class Plotter: else: plt.show() - def plot_loss(self, trainer, metric=None, label=None, log_scale=True): + def plot_loss(self, trainer, metrics=None, logy = False, logx=False, filename=None, **kwargs): """ Plot the loss function values during traininig. :param SolverInterface solver: the SolverInterface object. - :param str metric: the metric to use in the y axis. - :param str label: the label to use in the legend, defaults to None. - :param bool log_scale: If True, the y axis is in log scale. Default is + :param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss + is plotted. + :param bool logy: If True, the y axis is in log scale. Default is True. + :param bool logx: If True, the x axis is in log scale. Default is + True. + :param str filename: the file name to save the plot. If None, the plot + is shown using the setted matplotlib frontend. Default is None. """ # check that MetricTracker has been used @@ -206,21 +210,34 @@ class Plotter: raise FileNotFoundError('MetricTracker should be used as a callback during training to' ' use this method.') - metrics = trainer.callbacks[list_[0]].metrics + # extract trainer metrics + trainer_metrics = trainer.callbacks[list_[0]].metrics + if metrics is None: + metrics = ['mean_loss'] + elif not isinstance(metrics, list): + raise ValueError('metrics must be class list.') + + # loop over metrics to plot + for metric in metrics: + if metric not in trainer_metrics: + raise ValueError(f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.') + loss = trainer_metrics[metric] + epochs = range(len(loss)) + plt.plot(epochs, loss, label=metric, **kwargs) - if not metric: - metric = 'mean_loss' - - loss = metrics[metric] - epochs = range(len(loss)) - - if label is not None: - plt.plot(epochs, loss, label=label) - plt.legend() - else: - plt.plot(epochs, loss) - - if log_scale: - plt.yscale('log') + # plotting plt.xlabel('epoch') - plt.ylabel(metric) + plt.ylabel('loss') + plt.legend() + + # log axis + if logy: + plt.yscale('log') + if logx: + plt.xscale('log') + + # saving in file + if filename: + plt.savefig(filename) + else: + plt.show()