Tutorials v0.1 (#178)

Tutorial update and small fixes

* Update tutorials and add an FNO tutorial
* Create a MetricTracker callback (usage sketch below)
* Update PINN for logging
* Update Plotter: plot from the Trainer, add plot_loss
* Small fix in LabelTensor
* Small fix in FNO
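
A minimal usage sketch for the new MetricTracker callback (illustrative, not
part of this commit: it assumes a PINA Trainer that forwards a callbacks list
to the underlying Lightning trainer, and an already-built solver such as a PINN):

    from pina import Trainer
    from pina.callbacks import MetricTracker

    tracker = MetricTracker()               # records metrics at every epoch
    trainer = Trainer(solver, callbacks=[tracker])  # `solver`: assumed built
    trainer.train()                         # metrics available after training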

---------

Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
Author: Dario Coscia
Date: 2023-09-26 17:29:37 +02:00
Committed by: Nicola Demo
Parent: 939353f517
Commit: a9b1bd2826

45 changed files with 2760 additions and 1321 deletions

@@ -1,6 +1,7 @@
 """ Module for plotting. """
 import matplotlib.pyplot as plt
 import torch
+from pina.callbacks import MetricTracker
 from pina import LabelTensor
@@ -129,12 +130,12 @@ class Plotter:
             *grids, pred_output.cpu().detach(), **kwargs)
         fig.colorbar(cb, ax=ax)
 
-    def plot(self, solver, components=None, fixed_variables={}, method='contourf',
+    def plot(self, trainer, components=None, fixed_variables={}, method='contourf',
              res=256, filename=None, **kwargs):
         """
         Plot a sample of the SolverInterface output.
 
-        :param SolverInterface solver: the SolverInterface object.
+        :param Trainer trainer: the Trainer object.
         :param list(str) components: the output variable to plot. If None, all
             the output variables of the problem are selected. Default value is
             None.
@@ -149,6 +150,7 @@ class Plotter:
         :param str filename: the file name to save the plot. If None, the plot
             is shown using the set matplotlib frontend. Default is None.
         """
+        solver = trainer.solver
         if components is None:
             components = [solver.problem.output_variables]
         v = [
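
Call-site impact of the change above, as a hedged sketch (the Plotter
construction, import path, and variable names are illustrative, not taken
from this diff): plotting now goes through the Trainer, which carries its
solver.

    from pina.plotter import Plotter

    plotter = Plotter()
    plotter.plot(trainer)    # new: pass the Trainer
    # plotter.plot(solver)   # old: passed the SolverInterface directly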
@@ -186,25 +188,38 @@ class Plotter:
         else:
             plt.show()
 
-    # TODO loss
-    # def plot_loss(self, solver, label=None, log_scale=True):
-    #     """
-    #     Plot the loss function values during traininig.
-    #     :param SolverInterface solver: the SolverInterface object.
-    #     :param str label: the label to use in the legend, defaults to None.
-    #     :param bool log_scale: If True, the y axis is in log scale. Default is
-    #         True.
-    #     """
-    #     if not label:
-    #         label = str(solver)
-    #     epochs = list(solver.history_loss.keys())
-    #     loss = np.array(list(solver.history_loss.values()))
-    #     if loss.ndim != 1:
-    #         loss = loss[:, 0]
-    #     plt.plot(epochs, loss, label=label)
-    #     if log_scale:
-    #         plt.yscale('log')
+    def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
+        """
+        Plot the loss function values during training.
+
+        :param Trainer trainer: the Trainer object.
+        :param str metric: the metric to plot on the y axis.
+        :param str label: the label to use in the legend, defaults to None.
+        :param bool log_scale: If True, the y axis is in log scale. Default is
+            True.
+        """
+        # check that MetricTracker has been used as a callback
+        list_ = [idx for idx, s in enumerate(trainer.callbacks)
+                 if isinstance(s, MetricTracker)]
+        if not list_:
+            raise RuntimeError(
+                'MetricTracker should be used as a callback during training '
+                'to use this method.')
+        metrics = trainer.callbacks[list_[0]].metrics
+        if metric is None:
+            metric = 'mean_loss'
+        loss = metrics[metric]
+        epochs = range(len(loss))
+        if label is not None:
+            plt.plot(epochs, loss, label=label)
+            plt.legend()
+        else:
+            plt.plot(epochs, loss)
+        if log_scale:
+            plt.yscale('log')
+        plt.xlabel('epoch')
+        plt.ylabel(metric)
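
Putting the pieces together, a sketch of inspecting the tracked loss after
training (assumes the Trainer and MetricTracker from the sketches above;
'mean_loss' is the default metric key used by plot_loss in this diff):

    import matplotlib.pyplot as plt

    plotter.plot_loss(trainer, label='PINN', log_scale=True)
    plt.show()   # plot_loss draws on the current figure but does not show it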