Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -1,6 +1,5 @@
|
||||
""" Module for plotting. """
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pina.callbacks import MetricTracker
|
||||
@@ -14,9 +13,9 @@ class Plotter:
|
||||
|
||||
def plot_samples(self, problem, variables=None, **kwargs):
|
||||
"""
|
||||
Plot the training grid samples.
|
||||
Plot the training grid samples.
|
||||
|
||||
:param SolverInterface solver: The SolverInterface object.
|
||||
:param SolverInterface solver: The ``SolverInterface`` object.
|
||||
:param list(str) variables: Variables to plot. If None, all variables
|
||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||
'temporal', only temporal variables are plotted. Defaults to None.
|
||||
@@ -44,11 +43,13 @@ class Plotter:
|
||||
proj = '3d' if len(variables) == 3 else None
|
||||
ax = fig.add_subplot(projection=proj)
|
||||
for location in problem.input_pts:
|
||||
coords = problem.input_pts[location].extract(
|
||||
variables).T.detach()
|
||||
coords = problem.input_pts[location].extract(variables).T.detach()
|
||||
if coords.shape[0] == 1: # 1D samples
|
||||
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
|
||||
label=location, **kwargs)
|
||||
ax.plot(coords.flatten(),
|
||||
torch.zeros(coords.flatten().shape),
|
||||
'.',
|
||||
label=location,
|
||||
**kwargs)
|
||||
else:
|
||||
ax.plot(*coords, '.', label=location, **kwargs)
|
||||
|
||||
@@ -92,13 +93,19 @@ class Plotter:
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||
def _2d_plot(self,
|
||||
pts,
|
||||
pred,
|
||||
v,
|
||||
res,
|
||||
method,
|
||||
truth_solution=None,
|
||||
**kwargs):
|
||||
"""Plot solution for two dimensional function
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:param pred: ``SolverInterface`` solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: Matplotlib method to plot 2-dimensional data,
|
||||
see https://matplotlib.org/stable/api/axes_api.html for
|
||||
@@ -116,32 +123,39 @@ class Plotter:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
ax[0].title.set_text('Neural Network prediction')
|
||||
cb = getattr(ax[1], method)(
|
||||
*grids, truth_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
ax[1].title.set_text('True solution')
|
||||
cb = getattr(ax[2], method)(*grids,
|
||||
(truth_output-pred_output).cpu().detach(),
|
||||
**kwargs)
|
||||
cb = getattr(ax[2],
|
||||
method)(*grids,
|
||||
(truth_output - pred_output).cpu().detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
ax[2].title.set_text('Residual')
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
ax.title.set_text('Neural Network prediction')
|
||||
|
||||
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
def plot(self,
|
||||
solver,
|
||||
components=None,
|
||||
fixed_variables={},
|
||||
method='contourf',
|
||||
res=256,
|
||||
filename=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Plot sample of SolverInterface output.
|
||||
|
||||
:param SolverInterface solver: The SolverInterface object instance.
|
||||
:param SolverInterface solver: The ``SolverInterface`` object instance.
|
||||
:param list(str) components: The output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
@@ -149,8 +163,9 @@ class Plotter:
|
||||
should be kept fixed during the plot. The keys of the dictionary
|
||||
are the variables name whereas the values are the corresponding
|
||||
values of the variables. Defaults is `dict()`.
|
||||
:param {'contourf', 'pcolor'} method: The matplotlib method to use for
|
||||
plotting the solution. Default is 'contourf'.
|
||||
:param str method: The matplotlib method to use for
|
||||
plotting the solution. Available methods are {'contourf', 'pcolor'}.
|
||||
Default is 'contourf'.
|
||||
:param int res: The resolution, aka the number of points used for
|
||||
plotting in each axis. Default is 256.
|
||||
:param str filename: The file name to save the plot. If None, the plot
|
||||
@@ -184,8 +199,8 @@ class Plotter:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
elif len(v) == 2:
|
||||
self._2d_plot(pts, predicted_output, v, res, method,
|
||||
truth_solution, **kwargs)
|
||||
self._2d_plot(pts, predicted_output, v, res, method, truth_solution,
|
||||
**kwargs)
|
||||
|
||||
plt.tight_layout()
|
||||
if filename:
|
||||
@@ -193,12 +208,19 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, trainer, metrics=None, logy = False, logx=False, filename=None, **kwargs):
|
||||
def plot_loss(self,
|
||||
trainer,
|
||||
metrics=None,
|
||||
logy=False,
|
||||
logx=False,
|
||||
filename=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
:param Trainer trainer: the PINA Trainer object instance.
|
||||
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
||||
:param trainer: the PINA Trainer object instance.
|
||||
:type trainer: Trainer
|
||||
:param str | list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
||||
is plotted.
|
||||
:param bool logy: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
@@ -209,10 +231,14 @@ class Plotter:
|
||||
"""
|
||||
|
||||
# check that MetricTracker has been used
|
||||
list_ = [idx for idx, s in enumerate(trainer.callbacks) if isinstance(s, MetricTracker)]
|
||||
list_ = [
|
||||
idx for idx, s in enumerate(trainer.callbacks)
|
||||
if isinstance(s, MetricTracker)
|
||||
]
|
||||
if not bool(list_):
|
||||
raise FileNotFoundError('MetricTracker should be used as a callback during training to'
|
||||
' use this method.')
|
||||
raise FileNotFoundError(
|
||||
'MetricTracker should be used as a callback during training to'
|
||||
' use this method.')
|
||||
|
||||
# extract trainer metrics
|
||||
trainer_metrics = trainer.callbacks[list_[0]].metrics
|
||||
@@ -220,11 +246,13 @@ class Plotter:
|
||||
metrics = ['mean_loss']
|
||||
elif not isinstance(metrics, list):
|
||||
raise ValueError('metrics must be class list.')
|
||||
|
||||
|
||||
# loop over metrics to plot
|
||||
for metric in metrics:
|
||||
if metric not in trainer_metrics:
|
||||
raise ValueError(f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.')
|
||||
raise ValueError(
|
||||
f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.'
|
||||
)
|
||||
loss = trainer_metrics[metric]
|
||||
epochs = range(len(loss))
|
||||
plt.plot(epochs, loss, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user