update plot_samples, plot methods

This commit is contained in:
Dario Coscia
2023-10-31 10:54:04 +01:00
committed by Nicola Demo
parent 8cb4df13f0
commit 5336f36f08

View File

@@ -1,8 +1,9 @@
""" Module for plotting. """ """ Module for plotting. """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from pina.callbacks import MetricTracker from pina.callbacks import MetricTracker
from pina.utils import check_consistency
from pina import LabelTensor from pina import LabelTensor
@@ -11,12 +12,12 @@ class Plotter:
Implementation of a plotter class, for easy visualizations. Implementation of a plotter class, for easy visualizations.
""" """
def plot_samples(self, problem, variables=None): def plot_samples(self, problem, variables=None, **kwargs):
""" """
Plot the training grid samples. Plot the training grid samples.
:param SolverInterface solver: the SolverInterface object. :param SolverInterface solver: The SolverInterface object.
:param list(str) variables: variables to plot. If None, all variables :param list(str) variables: Variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None. 'temporal', only temporal variables are plotted. Defaults to None.
@@ -26,7 +27,7 @@ class Plotter:
:Example: :Example:
>>> plotter = Plotter() >>> plotter = Plotter()
>>> plotter.plot_samples(solver=solver, variables='spatial') >>> plotter.plot_samples(problem=problem, variables='spatial')
""" """
if variables is None: if variables is None:
@@ -47,9 +48,9 @@ class Plotter:
variables).T.detach() variables).T.detach()
if coords.shape[0] == 1: # 1D samples if coords.shape[0] == 1: # 1D samples
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.', ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
label=location) label=location, **kwargs)
else: else:
ax.plot(*coords, '.', label=location) ax.plot(*coords, '.', label=location, **kwargs)
ax.set_xlabel(variables[0]) ax.set_xlabel(variables[0])
try: try:
@@ -72,7 +73,7 @@ class Plotter:
:type pts: torch.Tensor :type pts: torch.Tensor
:param pred: SolverInterface solution evaluated at 'pts'. :param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor :type pred: torch.Tensor
:param method: not used, kept for code compatibility :param method: Not used, kept for code compatibility
:type method: None :type method: None
:param truth_solution: Real solution evaluated at 'pts', :param truth_solution: Real solution evaluated at 'pts',
defaults to None. defaults to None.
@@ -80,11 +81,11 @@ class Plotter:
""" """
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach(), label='neural net solution', **kwargs) ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs)
if truth_solution: if truth_solution:
truth_output = truth_solution(pts).float() truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach(), label='true solution', **kwargs) ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
plt.xlabel(pts.labels[0]) plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0]) plt.ylabel(pred.labels[0])
@@ -99,7 +100,7 @@ class Plotter:
:type pts: torch.Tensor :type pts: torch.Tensor
:param pred: SolverInterface solution evaluated at 'pts'. :param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor :type pred: torch.Tensor
:param method: matplotlib method to plot 2-dimensional data, :param method: Matplotlib method to plot 2-dimensional data,
see https://matplotlib.org/stable/api/axes_api.html for see https://matplotlib.org/stable/api/axes_api.html for
reference. reference.
:type method: str :type method: str
@@ -118,40 +119,44 @@ class Plotter:
cb = getattr(ax[0], method)( cb = getattr(ax[0], method)(
*grids, pred_output.cpu().detach(), **kwargs) *grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[0]) fig.colorbar(cb, ax=ax[0])
ax[0].title.set_text('Neural Network prediction')
cb = getattr(ax[1], method)( cb = getattr(ax[1], method)(
*grids, truth_output.cpu().detach(), **kwargs) *grids, truth_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[1]) fig.colorbar(cb, ax=ax[1])
ax[1].title.set_text('True solution')
cb = getattr(ax[2], method)(*grids, cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).cpu().detach(), (truth_output-pred_output).cpu().detach(),
**kwargs) **kwargs)
fig.colorbar(cb, ax=ax[2]) fig.colorbar(cb, ax=ax[2])
ax[2].title.set_text('Residual')
else: else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)( cb = getattr(ax, method)(
*grids, pred_output.cpu().detach(), **kwargs) *grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax) fig.colorbar(cb, ax=ax)
ax.title.set_text('Neural Network prediction')
def plot(self, trainer, components=None, fixed_variables={}, method='contourf', def plot(self, solver, components=None, fixed_variables={}, method='contourf',
res=256, filename=None, **kwargs): res=256, filename=None, **kwargs):
""" """
Plot sample of SolverInterface output. Plot sample of SolverInterface output.
:param Trainer trainer: the Trainer object. :param SolverInterface solver: The SolverInterface object instance.
:param list(str) components: the output variable to plot. If None, all :param list(str) components: The output variable to plot. If None, all
the output variables of the problem are selected. Default value is the output variables of the problem are selected. Default value is
None. None.
:param dict fixed_variables: a dictionary with all the variables that :param dict fixed_variables: A dictionary with all the variables that
should be kept fixed during the plot. The keys of the dictionary should be kept fixed during the plot. The keys of the dictionary
are the variables name whereas the values are the corresponding are the variables name whereas the values are the corresponding
values of the variables. Defaults is `dict()`. values of the variables. Defaults is `dict()`.
:param {'contourf', 'pcolor'} method: the matplotlib method to use for :param {'contourf', 'pcolor'} method: The matplotlib method to use for
plotting the solution. Default is 'contourf'. plotting the solution. Default is 'contourf'.
:param int res: the resolution, aka the number of points used for :param int res: The resolution, aka the number of points used for
plotting in each axis. Default is 256. plotting in each axis. Default is 256.
:param str filename: the file name to save the plot. If None, the plot :param str filename: The file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None. is shown using the setted matplotlib frontend. Default is None.
""" """
solver = trainer.solver
if components is None: if components is None:
components = [solver.problem.output_variables] components = [solver.problem.output_variables]
v = [ v = [
@@ -182,9 +187,8 @@ class Plotter:
self._2d_plot(pts, predicted_output, v, res, method, self._2d_plot(pts, predicted_output, v, res, method,
truth_solution, **kwargs) truth_solution, **kwargs)
plt.tight_layout()
if filename: if filename:
plt.title('Output {} with parameter {}'.format(components,
fixed_variables))
plt.savefig(filename) plt.savefig(filename)
else: else:
plt.show() plt.show()
@@ -193,14 +197,14 @@ class Plotter:
""" """
Plot the loss function values during traininig. Plot the loss function values during traininig.
:param SolverInterface solver: the SolverInterface object. :param Trainer trainer: the PINA Trainer object instance.
:param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss :param str/list(str) metric: The metrics to use in the y axis. If None, the mean loss
is plotted. is plotted.
:param bool logy: If True, the y axis is in log scale. Default is :param bool logy: If True, the y axis is in log scale. Default is
True. True.
:param bool logx: If True, the x axis is in log scale. Default is :param bool logx: If True, the x axis is in log scale. Default is
True. True.
:param str filename: the file name to save the plot. If None, the plot :param str filename: The file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None. is shown using the setted matplotlib frontend. Default is None.
""" """