plot_loss update
This commit is contained in:
committed by
Nicola Demo
parent
ead952da6f
commit
8cb4df13f0
@@ -2,7 +2,7 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from pina.callbacks import MetricTracker
|
from pina.callbacks import MetricTracker
|
||||||
|
from pina.utils import check_consistency
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
@@ -189,15 +189,19 @@ class Plotter:
|
|||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
|
def plot_loss(self, trainer, metrics=None, logy = False, logx=False, filename=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Plot the loss function values during traininig.
|
Plot the loss function values during traininig.
|
||||||
|
|
||||||
:param SolverInterface solver: the SolverInterface object.
|
:param SolverInterface solver: the SolverInterface object.
|
||||||
:param str metric: the metric to use in the y axis.
|
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
||||||
:param str label: the label to use in the legend, defaults to None.
|
is plotted.
|
||||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
:param bool logy: If True, the y axis is in log scale. Default is
|
||||||
True.
|
True.
|
||||||
|
:param bool logx: If True, the x axis is in log scale. Default is
|
||||||
|
True.
|
||||||
|
:param str filename: the file name to save the plot. If None, the plot
|
||||||
|
is shown using the setted matplotlib frontend. Default is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# check that MetricTracker has been used
|
# check that MetricTracker has been used
|
||||||
@@ -206,21 +210,34 @@ class Plotter:
|
|||||||
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
|
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
|
||||||
' use this method.')
|
' use this method.')
|
||||||
|
|
||||||
metrics = trainer.callbacks[list_[0]].metrics
|
# extract trainer metrics
|
||||||
|
trainer_metrics = trainer.callbacks[list_[0]].metrics
|
||||||
|
if metrics is None:
|
||||||
|
metrics = ['mean_loss']
|
||||||
|
elif not isinstance(metrics, list):
|
||||||
|
raise ValueError('metrics must be class list.')
|
||||||
|
|
||||||
|
# loop over metrics to plot
|
||||||
|
for metric in metrics:
|
||||||
|
if metric not in trainer_metrics:
|
||||||
|
raise ValueError(f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.')
|
||||||
|
loss = trainer_metrics[metric]
|
||||||
|
epochs = range(len(loss))
|
||||||
|
plt.plot(epochs, loss, label=metric, **kwargs)
|
||||||
|
|
||||||
if not metric:
|
# plotting
|
||||||
metric = 'mean_loss'
|
|
||||||
|
|
||||||
loss = metrics[metric]
|
|
||||||
epochs = range(len(loss))
|
|
||||||
|
|
||||||
if label is not None:
|
|
||||||
plt.plot(epochs, loss, label=label)
|
|
||||||
plt.legend()
|
|
||||||
else:
|
|
||||||
plt.plot(epochs, loss)
|
|
||||||
|
|
||||||
if log_scale:
|
|
||||||
plt.yscale('log')
|
|
||||||
plt.xlabel('epoch')
|
plt.xlabel('epoch')
|
||||||
plt.ylabel(metric)
|
plt.ylabel('loss')
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
# log axis
|
||||||
|
if logy:
|
||||||
|
plt.yscale('log')
|
||||||
|
if logx:
|
||||||
|
plt.xscale('log')
|
||||||
|
|
||||||
|
# saving in file
|
||||||
|
if filename:
|
||||||
|
plt.savefig(filename)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user