plotter update

This commit is contained in:
Dario Coscia
2024-01-16 11:38:39 +01:00
committed by Nicola Demo
parent 1464eee891
commit 927dbcf91e

View File

@@ -11,14 +11,16 @@ class Plotter:
Implementation of a plotter class, for easy visualizations.
"""
def plot_samples(self, problem, variables=None, **kwargs):
def plot_samples(self, problem, variables=None, filename=None, **kwargs):
"""
Plot the training grid samples.
:param SolverInterface solver: The ``SolverInterface`` object.
:param AbstractProblem problem: The PINA problem from where to plot the domain.
:param list(str) variables: Variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None.
:param str filename: The file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None.
.. todo::
- Add support for 3D plots.
@@ -65,15 +67,21 @@ class Plotter:
pass
plt.legend()
plt.show()
if filename:
plt.savefig(filename)
plt.close()
else:
plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs):
"""Plot solution for one dimensional function
:param pts: Points to plot the solution.
:type pts: torch.Tensor
:param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor
:param v: Fixed variables when plotting the solution.
:type v: torch.Tensor
:param method: Not used, kept for code compatibility
:type method: None
:param truth_solution: Real solution evaluated at 'pts',
@@ -82,18 +90,17 @@ class Plotter:
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs)
ax.plot(pts.extract(v), pred, label='Neural Network solution', **kwargs)
if truth_solution:
truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
truth_output = truth_solution(pts).detach()
ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs)
# TODO: pred is a torch.Tensor, so no labels is available
# extra variable for labels should be
# passed in the function arguments.
# plt.ylabel(pred.labels[0])
plt.legend()
plt.show()
def _2d_plot(self,
pts,
@@ -109,6 +116,8 @@ class Plotter:
:type pts: torch.Tensor
:param pred: ``SolverInterface`` solution evaluated at 'pts'.
:type pred: torch.Tensor
:param v: Fixed variables when plotting the solution.
:type v: torch.Tensor
:param method: Matplotlib method to plot 2-dimensional data,
see https://matplotlib.org/stable/api/axes_api.html for
reference.
@@ -118,30 +127,30 @@ class Plotter:
:type truth_solution: torch.Tensor, optional
"""
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
pred_output = pred.reshape(res, res)
if truth_solution:
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(),
cb = getattr(ax[0], method)(*grids, pred_output,
**kwargs)
fig.colorbar(cb, ax=ax[0])
ax[0].title.set_text('Neural Network prediction')
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(),
cb = getattr(ax[1], method)(*grids, truth_output,
**kwargs)
fig.colorbar(cb, ax=ax[1])
ax[1].title.set_text('True solution')
cb = getattr(ax[2],
method)(*grids,
(truth_output - pred_output).cpu().detach(),
(truth_output - pred_output),
**kwargs)
fig.colorbar(cb, ax=ax[2])
ax[2].title.set_text('Residual')
else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(),
cb = getattr(ax, method)(*grids, pred_output,
**kwargs)
fig.colorbar(cb, ax=ax)
ax.title.set_text('Neural Network prediction')
@@ -201,11 +210,13 @@ class Plotter:
pts = pts.append(fixed_pts)
pts = pts.to(device=solver.device)
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor)
# computing soluting and sending to cpu
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach()
pts = pts.cpu()
truth_solution = getattr(solver.problem, 'truth_solution', None)
if len(v) == 1:
self._1d_plot(pts.extract(v), predicted_output, method, truth_solution,
self._1d_plot(pts, predicted_output, v, method, truth_solution,
**kwargs)
elif len(v) == 2:
self._2d_plot(pts, predicted_output, v, res, method, truth_solution,
@@ -214,6 +225,7 @@ class Plotter:
plt.tight_layout()
if filename:
plt.savefig(filename)
plt.close()
else:
plt.show()
@@ -280,3 +292,4 @@ class Plotter:
# saving in file
if filename:
plt.savefig(filename)
plt.close()