Tutorials v0.1 (#178)
Tutorial update and small fixes * Tutorials update + Tutorial FNO * Create a metric tracker callback * Update PINN for logging * Update plotter for plotting * Small fix LabelTensor * Small fix FNO --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
939353f517
commit
a9b1bd2826
@@ -1,6 +1,7 @@
|
||||
""" Module for plotting. """
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pina.callbacks import MetricTracker
|
||||
|
||||
from pina import LabelTensor
|
||||
|
||||
@@ -129,12 +130,12 @@ class Plotter:
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
|
||||
def plot(self, trainer, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
Plot sample of SolverInterface output.
|
||||
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param Trainer trainer: the Trainer object.
|
||||
:param list(str) components: the output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
@@ -149,6 +150,7 @@ class Plotter:
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
"""
|
||||
solver = trainer.solver
|
||||
if components is None:
|
||||
components = [solver.problem.output_variables]
|
||||
v = [
|
||||
@@ -186,25 +188,38 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
# TODO loss
|
||||
# def plot_loss(self, solver, label=None, log_scale=True):
|
||||
# """
|
||||
# Plot the loss function values during traininig.
|
||||
def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
# :param SolverInterface solver: the SolverInterface object.
|
||||
# :param str label: the label to use in the legend, defaults to None.
|
||||
# :param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
# True.
|
||||
# """
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param str metric: the metric to use in the y axis.
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
"""
|
||||
|
||||
# if not label:
|
||||
# label = str(solver)
|
||||
# check that MetricTracker has been used
|
||||
list_ = [idx for idx, s in enumerate(trainer.callbacks) if isinstance(s, MetricTracker)]
|
||||
if not bool(list_):
|
||||
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
|
||||
' use this method.')
|
||||
|
||||
# epochs = list(solver.history_loss.keys())
|
||||
# loss = np.array(list(solver.history_loss.values()))
|
||||
# if loss.ndim != 1:
|
||||
# loss = loss[:, 0]
|
||||
metrics = trainer.callbacks[list_[0]].metrics
|
||||
|
||||
# plt.plot(epochs, loss, label=label)
|
||||
# if log_scale:
|
||||
# plt.yscale('log')
|
||||
if not metric:
|
||||
metric = 'mean_loss'
|
||||
|
||||
loss = metrics[metric]
|
||||
epochs = range(len(loss))
|
||||
|
||||
if label is not None:
|
||||
plt.plot(epochs, loss, label=label)
|
||||
plt.legend()
|
||||
else:
|
||||
plt.plot(epochs, loss)
|
||||
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel(metric)
|
||||
|
||||
Reference in New Issue
Block a user