plotter update

This commit is contained in:
Dario Coscia
2024-01-16 11:38:39 +01:00
committed by Nicola Demo
parent 1464eee891
commit 927dbcf91e

View File

@@ -11,14 +11,16 @@ class Plotter:
Implementation of a plotter class, for easy visualizations. Implementation of a plotter class, for easy visualizations.
""" """
def plot_samples(self, problem, variables=None, **kwargs): def plot_samples(self, problem, variables=None, filename=None, **kwargs):
""" """
Plot the training grid samples. Plot the training grid samples.
:param SolverInterface solver: The ``SolverInterface`` object. :param AbstractProblem problem: The PINA problem from where to plot the domain.
:param list(str) variables: Variables to plot. If None, all variables :param list(str) variables: Variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None. 'temporal', only temporal variables are plotted. Defaults to None.
:param str filename: The file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None.
.. todo:: .. todo::
- Add support for 3D plots. - Add support for 3D plots.
@@ -65,15 +67,21 @@ class Plotter:
pass pass
plt.legend() plt.legend()
plt.show() if filename:
plt.savefig(filename)
plt.close()
else:
plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs): def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs):
"""Plot solution for one dimensional function """Plot solution for one dimensional function
:param pts: Points to plot the solution. :param pts: Points to plot the solution.
:type pts: torch.Tensor :type pts: torch.Tensor
:param pred: SolverInterface solution evaluated at 'pts'. :param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor :type pred: torch.Tensor
:param v: Fixed variables when plotting the solution.
:type v: torch.Tensor
:param method: Not used, kept for code compatibility :param method: Not used, kept for code compatibility
:type method: None :type method: None
:param truth_solution: Real solution evaluated at 'pts', :param truth_solution: Real solution evaluated at 'pts',
@@ -82,18 +90,17 @@ class Plotter:
""" """
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs) ax.plot(pts.extract(v), pred, label='Neural Network solution', **kwargs)
if truth_solution: if truth_solution:
truth_output = truth_solution(pts).float() truth_output = truth_solution(pts).detach()
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs) ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs)
# TODO: pred is a torch.Tensor, so no labels is available # TODO: pred is a torch.Tensor, so no labels is available
# extra variable for labels should be # extra variable for labels should be
# passed in the function arguments. # passed in the function arguments.
# plt.ylabel(pred.labels[0]) # plt.ylabel(pred.labels[0])
plt.legend() plt.legend()
plt.show()
def _2d_plot(self, def _2d_plot(self,
pts, pts,
@@ -109,6 +116,8 @@ class Plotter:
:type pts: torch.Tensor :type pts: torch.Tensor
:param pred: ``SolverInterface`` solution evaluated at 'pts'. :param pred: ``SolverInterface`` solution evaluated at 'pts'.
:type pred: torch.Tensor :type pred: torch.Tensor
:param v: Fixed variables when plotting the solution.
:type v: torch.Tensor
:param method: Matplotlib method to plot 2-dimensional data, :param method: Matplotlib method to plot 2-dimensional data,
see https://matplotlib.org/stable/api/axes_api.html for see https://matplotlib.org/stable/api/axes_api.html for
reference. reference.
@@ -118,30 +127,30 @@ class Plotter:
:type truth_solution: torch.Tensor, optional :type truth_solution: torch.Tensor, optional
""" """
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T] grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
pred_output = pred.reshape(res, res) pred_output = pred.reshape(res, res)
if truth_solution: if truth_solution:
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor) truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), cb = getattr(ax[0], method)(*grids, pred_output,
**kwargs) **kwargs)
fig.colorbar(cb, ax=ax[0]) fig.colorbar(cb, ax=ax[0])
ax[0].title.set_text('Neural Network prediction') ax[0].title.set_text('Neural Network prediction')
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), cb = getattr(ax[1], method)(*grids, truth_output,
**kwargs) **kwargs)
fig.colorbar(cb, ax=ax[1]) fig.colorbar(cb, ax=ax[1])
ax[1].title.set_text('True solution') ax[1].title.set_text('True solution')
cb = getattr(ax[2], cb = getattr(ax[2],
method)(*grids, method)(*grids,
(truth_output - pred_output).cpu().detach(), (truth_output - pred_output),
**kwargs) **kwargs)
fig.colorbar(cb, ax=ax[2]) fig.colorbar(cb, ax=ax[2])
ax[2].title.set_text('Residual') ax[2].title.set_text('Residual')
else: else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), cb = getattr(ax, method)(*grids, pred_output,
**kwargs) **kwargs)
fig.colorbar(cb, ax=ax) fig.colorbar(cb, ax=ax)
ax.title.set_text('Neural Network prediction') ax.title.set_text('Neural Network prediction')
@@ -201,11 +210,13 @@ class Plotter:
pts = pts.append(fixed_pts) pts = pts.append(fixed_pts)
pts = pts.to(device=solver.device) pts = pts.to(device=solver.device)
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor) # computing soluting and sending to cpu
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach()
pts = pts.cpu()
truth_solution = getattr(solver.problem, 'truth_solution', None) truth_solution = getattr(solver.problem, 'truth_solution', None)
if len(v) == 1: if len(v) == 1:
self._1d_plot(pts.extract(v), predicted_output, method, truth_solution, self._1d_plot(pts, predicted_output, v, method, truth_solution,
**kwargs) **kwargs)
elif len(v) == 2: elif len(v) == 2:
self._2d_plot(pts, predicted_output, v, res, method, truth_solution, self._2d_plot(pts, predicted_output, v, res, method, truth_solution,
@@ -214,6 +225,7 @@ class Plotter:
plt.tight_layout() plt.tight_layout()
if filename: if filename:
plt.savefig(filename) plt.savefig(filename)
plt.close()
else: else:
plt.show() plt.show()
@@ -280,3 +292,4 @@ class Plotter:
# saving in file # saving in file
if filename: if filename:
plt.savefig(filename) plt.savefig(filename)
plt.close()