plot_loss update
This commit is contained in:
committed by
Nicola Demo
parent
ead952da6f
commit
8cb4df13f0
@@ -2,7 +2,7 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pina.callbacks import MetricTracker
|
||||
|
||||
from pina.utils import check_consistency
|
||||
from pina import LabelTensor
|
||||
|
||||
|
||||
@@ -189,15 +189,19 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
|
||||
def plot_loss(self, trainer, metrics=None, logy = False, logx=False, filename=None, **kwargs):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param str metric: the metric to use in the y axis.
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
||||
is plotted.
|
||||
:param bool logy: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
:param bool logx: If True, the x axis is in log scale. Default is
|
||||
True.
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
"""
|
||||
|
||||
# check that MetricTracker has been used
|
||||
@@ -206,21 +210,34 @@ class Plotter:
|
||||
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
|
||||
' use this method.')
|
||||
|
||||
metrics = trainer.callbacks[list_[0]].metrics
|
||||
# extract trainer metrics
|
||||
trainer_metrics = trainer.callbacks[list_[0]].metrics
|
||||
if metrics is None:
|
||||
metrics = ['mean_loss']
|
||||
elif not isinstance(metrics, list):
|
||||
raise ValueError('metrics must be class list.')
|
||||
|
||||
# loop over metrics to plot
|
||||
for metric in metrics:
|
||||
if metric not in trainer_metrics:
|
||||
raise ValueError(f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.')
|
||||
loss = trainer_metrics[metric]
|
||||
epochs = range(len(loss))
|
||||
plt.plot(epochs, loss, label=metric, **kwargs)
|
||||
|
||||
if not metric:
|
||||
metric = 'mean_loss'
|
||||
|
||||
loss = metrics[metric]
|
||||
epochs = range(len(loss))
|
||||
|
||||
if label is not None:
|
||||
plt.plot(epochs, loss, label=label)
|
||||
plt.legend()
|
||||
else:
|
||||
plt.plot(epochs, loss)
|
||||
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
# plotting
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel(metric)
|
||||
plt.ylabel('loss')
|
||||
plt.legend()
|
||||
|
||||
# log axis
|
||||
if logy:
|
||||
plt.yscale('log')
|
||||
if logx:
|
||||
plt.xscale('log')
|
||||
|
||||
# saving in file
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user