update plot_samples, plot methods
This commit is contained in:
committed by
Nicola Demo
parent
8cb4df13f0
commit
5336f36f08
@@ -1,8 +1,9 @@
|
|||||||
""" Module for plotting. """
|
""" Module for plotting. """
|
||||||
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from pina.callbacks import MetricTracker
|
from pina.callbacks import MetricTracker
|
||||||
from pina.utils import check_consistency
|
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
@@ -11,12 +12,12 @@ class Plotter:
|
|||||||
Implementation of a plotter class, for easy visualizations.
|
Implementation of a plotter class, for easy visualizations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def plot_samples(self, problem, variables=None):
|
def plot_samples(self, problem, variables=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Plot the training grid samples.
|
Plot the training grid samples.
|
||||||
|
|
||||||
:param SolverInterface solver: the SolverInterface object.
|
:param SolverInterface solver: The SolverInterface object.
|
||||||
:param list(str) variables: variables to plot. If None, all variables
|
:param list(str) variables: Variables to plot. If None, all variables
|
||||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||||
'temporal', only temporal variables are plotted. Defaults to None.
|
'temporal', only temporal variables are plotted. Defaults to None.
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ class Plotter:
|
|||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> plotter = Plotter()
|
>>> plotter = Plotter()
|
||||||
>>> plotter.plot_samples(solver=solver, variables='spatial')
|
>>> plotter.plot_samples(problem=problem, variables='spatial')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if variables is None:
|
if variables is None:
|
||||||
@@ -47,9 +48,9 @@ class Plotter:
|
|||||||
variables).T.detach()
|
variables).T.detach()
|
||||||
if coords.shape[0] == 1: # 1D samples
|
if coords.shape[0] == 1: # 1D samples
|
||||||
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
|
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
|
||||||
label=location)
|
label=location, **kwargs)
|
||||||
else:
|
else:
|
||||||
ax.plot(*coords, '.', label=location)
|
ax.plot(*coords, '.', label=location, **kwargs)
|
||||||
|
|
||||||
ax.set_xlabel(variables[0])
|
ax.set_xlabel(variables[0])
|
||||||
try:
|
try:
|
||||||
@@ -72,7 +73,7 @@ class Plotter:
|
|||||||
:type pts: torch.Tensor
|
:type pts: torch.Tensor
|
||||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||||
:type pred: torch.Tensor
|
:type pred: torch.Tensor
|
||||||
:param method: not used, kept for code compatibility
|
:param method: Not used, kept for code compatibility
|
||||||
:type method: None
|
:type method: None
|
||||||
:param truth_solution: Real solution evaluated at 'pts',
|
:param truth_solution: Real solution evaluated at 'pts',
|
||||||
defaults to None.
|
defaults to None.
|
||||||
@@ -80,11 +81,11 @@ class Plotter:
|
|||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||||
|
|
||||||
ax.plot(pts, pred.detach(), label='neural net solution', **kwargs)
|
ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs)
|
||||||
|
|
||||||
if truth_solution:
|
if truth_solution:
|
||||||
truth_output = truth_solution(pts).float()
|
truth_output = truth_solution(pts).float()
|
||||||
ax.plot(pts, truth_output.detach(), label='true solution', **kwargs)
|
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
|
||||||
|
|
||||||
plt.xlabel(pts.labels[0])
|
plt.xlabel(pts.labels[0])
|
||||||
plt.ylabel(pred.labels[0])
|
plt.ylabel(pred.labels[0])
|
||||||
@@ -99,7 +100,7 @@ class Plotter:
|
|||||||
:type pts: torch.Tensor
|
:type pts: torch.Tensor
|
||||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||||
:type pred: torch.Tensor
|
:type pred: torch.Tensor
|
||||||
:param method: matplotlib method to plot 2-dimensional data,
|
:param method: Matplotlib method to plot 2-dimensional data,
|
||||||
see https://matplotlib.org/stable/api/axes_api.html for
|
see https://matplotlib.org/stable/api/axes_api.html for
|
||||||
reference.
|
reference.
|
||||||
:type method: str
|
:type method: str
|
||||||
@@ -118,40 +119,44 @@ class Plotter:
|
|||||||
cb = getattr(ax[0], method)(
|
cb = getattr(ax[0], method)(
|
||||||
*grids, pred_output.cpu().detach(), **kwargs)
|
*grids, pred_output.cpu().detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax[0])
|
fig.colorbar(cb, ax=ax[0])
|
||||||
|
ax[0].title.set_text('Neural Network prediction')
|
||||||
cb = getattr(ax[1], method)(
|
cb = getattr(ax[1], method)(
|
||||||
*grids, truth_output.cpu().detach(), **kwargs)
|
*grids, truth_output.cpu().detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax[1])
|
fig.colorbar(cb, ax=ax[1])
|
||||||
|
ax[1].title.set_text('True solution')
|
||||||
cb = getattr(ax[2], method)(*grids,
|
cb = getattr(ax[2], method)(*grids,
|
||||||
(truth_output-pred_output).cpu().detach(),
|
(truth_output-pred_output).cpu().detach(),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
fig.colorbar(cb, ax=ax[2])
|
fig.colorbar(cb, ax=ax[2])
|
||||||
|
ax[2].title.set_text('Residual')
|
||||||
else:
|
else:
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(ax, method)(
|
cb = getattr(ax, method)(
|
||||||
*grids, pred_output.cpu().detach(), **kwargs)
|
*grids, pred_output.cpu().detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax)
|
fig.colorbar(cb, ax=ax)
|
||||||
|
ax.title.set_text('Neural Network prediction')
|
||||||
|
|
||||||
def plot(self, trainer, components=None, fixed_variables={}, method='contourf',
|
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
|
||||||
res=256, filename=None, **kwargs):
|
res=256, filename=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Plot sample of SolverInterface output.
|
Plot sample of SolverInterface output.
|
||||||
|
|
||||||
:param Trainer trainer: the Trainer object.
|
:param SolverInterface solver: The SolverInterface object instance.
|
||||||
:param list(str) components: the output variable to plot. If None, all
|
:param list(str) components: The output variable to plot. If None, all
|
||||||
the output variables of the problem are selected. Default value is
|
the output variables of the problem are selected. Default value is
|
||||||
None.
|
None.
|
||||||
:param dict fixed_variables: a dictionary with all the variables that
|
:param dict fixed_variables: A dictionary with all the variables that
|
||||||
should be kept fixed during the plot. The keys of the dictionary
|
should be kept fixed during the plot. The keys of the dictionary
|
||||||
are the variables name whereas the values are the corresponding
|
are the variables name whereas the values are the corresponding
|
||||||
values of the variables. Defaults is `dict()`.
|
values of the variables. Defaults is `dict()`.
|
||||||
:param {'contourf', 'pcolor'} method: the matplotlib method to use for
|
:param {'contourf', 'pcolor'} method: The matplotlib method to use for
|
||||||
plotting the solution. Default is 'contourf'.
|
plotting the solution. Default is 'contourf'.
|
||||||
:param int res: the resolution, aka the number of points used for
|
:param int res: The resolution, aka the number of points used for
|
||||||
plotting in each axis. Default is 256.
|
plotting in each axis. Default is 256.
|
||||||
:param str filename: the file name to save the plot. If None, the plot
|
:param str filename: The file name to save the plot. If None, the plot
|
||||||
is shown using the setted matplotlib frontend. Default is None.
|
is shown using the setted matplotlib frontend. Default is None.
|
||||||
"""
|
"""
|
||||||
solver = trainer.solver
|
|
||||||
if components is None:
|
if components is None:
|
||||||
components = [solver.problem.output_variables]
|
components = [solver.problem.output_variables]
|
||||||
v = [
|
v = [
|
||||||
@@ -182,9 +187,8 @@ class Plotter:
|
|||||||
self._2d_plot(pts, predicted_output, v, res, method,
|
self._2d_plot(pts, predicted_output, v, res, method,
|
||||||
truth_solution, **kwargs)
|
truth_solution, **kwargs)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
if filename:
|
if filename:
|
||||||
plt.title('Output {} with parameter {}'.format(components,
|
|
||||||
fixed_variables))
|
|
||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
@@ -193,14 +197,14 @@ class Plotter:
|
|||||||
"""
|
"""
|
||||||
Plot the loss function values during traininig.
|
Plot the loss function values during traininig.
|
||||||
|
|
||||||
:param SolverInterface solver: the SolverInterface object.
|
:param Trainer trainer: the PINA Trainer object instance.
|
||||||
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
||||||
is plotted.
|
is plotted.
|
||||||
:param bool logy: If True, the y axis is in log scale. Default is
|
:param bool logy: If True, the y axis is in log scale. Default is
|
||||||
True.
|
True.
|
||||||
:param bool logx: If True, the x axis is in log scale. Default is
|
:param bool logx: If True, the x axis is in log scale. Default is
|
||||||
True.
|
True.
|
||||||
:param str filename: the file name to save the plot. If None, the plot
|
:param str filename: The file name to save the plot. If None, the plot
|
||||||
is shown using the setted matplotlib frontend. Default is None.
|
is shown using the setted matplotlib frontend. Default is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user