plot_loss update

This commit is contained in:
Dario Coscia
2023-10-30 13:02:07 +01:00
committed by Nicola Demo
parent ead952da6f
commit 8cb4df13f0

View File

@@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
import torch
from pina.callbacks import MetricTracker
from pina.utils import check_consistency
from pina import LabelTensor
@@ -189,15 +189,19 @@ class Plotter:
else:
plt.show()
def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
def plot_loss(self, trainer, metrics=None, logy = False, logx=False, filename=None, **kwargs):
"""
Plot the loss function values during traininig.
:param SolverInterface solver: the SolverInterface object.
:param str metric: the metric to use in the y axis.
:param str label: the label to use in the legend, defaults to None.
:param bool log_scale: If True, the y axis is in log scale. Default is
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
is plotted.
:param bool logy: If True, the y axis is in log scale. Default is
True.
:param bool logx: If True, the x axis is in log scale. Default is
True.
:param str filename: the file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None.
"""
# check that MetricTracker has been used
@@ -206,21 +210,34 @@ class Plotter:
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
' use this method.')
metrics = trainer.callbacks[list_[0]].metrics
# extract trainer metrics
trainer_metrics = trainer.callbacks[list_[0]].metrics
if metrics is None:
metrics = ['mean_loss']
elif not isinstance(metrics, list):
raise ValueError('metrics must be class list.')
# loop over metrics to plot
for metric in metrics:
if metric not in trainer_metrics:
raise ValueError(f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.')
loss = trainer_metrics[metric]
epochs = range(len(loss))
plt.plot(epochs, loss, label=metric, **kwargs)
if not metric:
metric = 'mean_loss'
loss = metrics[metric]
epochs = range(len(loss))
if label is not None:
plt.plot(epochs, loss, label=label)
plt.legend()
else:
plt.plot(epochs, loss)
if log_scale:
plt.yscale('log')
# plotting
plt.xlabel('epoch')
plt.ylabel(metric)
plt.ylabel('loss')
plt.legend()
# log axis
if logy:
plt.yscale('log')
if logx:
plt.xscale('log')
# saving in file
if filename:
plt.savefig(filename)
else:
plt.show()