plotter update
This commit is contained in:
committed by
Nicola Demo
parent
1464eee891
commit
927dbcf91e
@@ -11,14 +11,16 @@ class Plotter:
|
||||
Implementation of a plotter class, for easy visualizations.
|
||||
"""
|
||||
|
||||
def plot_samples(self, problem, variables=None, **kwargs):
|
||||
def plot_samples(self, problem, variables=None, filename=None, **kwargs):
|
||||
"""
|
||||
Plot the training grid samples.
|
||||
|
||||
:param SolverInterface solver: The ``SolverInterface`` object.
|
||||
:param AbstractProblem problem: The PINA problem from where to plot the domain.
|
||||
:param list(str) variables: Variables to plot. If None, all variables
|
||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||
'temporal', only temporal variables are plotted. Defaults to None.
|
||||
:param str filename: The file name to save the plot. If None, the plot
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
|
||||
.. todo::
|
||||
- Add support for 3D plots.
|
||||
@@ -65,15 +67,21 @@ class Plotter:
|
||||
pass
|
||||
|
||||
plt.legend()
|
||||
plt.show()
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
|
||||
def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs):
|
||||
"""Plot solution for one dimensional function
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param v: Fixed variables when plotting the solution.
|
||||
:type v: torch.Tensor
|
||||
:param method: Not used, kept for code compatibility
|
||||
:type method: None
|
||||
:param truth_solution: Real solution evaluated at 'pts',
|
||||
@@ -82,18 +90,17 @@ class Plotter:
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs)
|
||||
ax.plot(pts.extract(v), pred, label='Neural Network solution', **kwargs)
|
||||
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float()
|
||||
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
|
||||
truth_output = truth_solution(pts).detach()
|
||||
ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs)
|
||||
|
||||
# TODO: pred is a torch.Tensor, so no labels is available
|
||||
# extra variable for labels should be
|
||||
# passed in the function arguments.
|
||||
# plt.ylabel(pred.labels[0])
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def _2d_plot(self,
|
||||
pts,
|
||||
@@ -109,6 +116,8 @@ class Plotter:
|
||||
:type pts: torch.Tensor
|
||||
:param pred: ``SolverInterface`` solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param v: Fixed variables when plotting the solution.
|
||||
:type v: torch.Tensor
|
||||
:param method: Matplotlib method to plot 2-dimensional data,
|
||||
see https://matplotlib.org/stable/api/axes_api.html for
|
||||
reference.
|
||||
@@ -118,30 +127,30 @@ class Plotter:
|
||||
:type truth_solution: torch.Tensor, optional
|
||||
"""
|
||||
|
||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
|
||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(),
|
||||
cb = getattr(ax[0], method)(*grids, pred_output,
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
ax[0].title.set_text('Neural Network prediction')
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(),
|
||||
cb = getattr(ax[1], method)(*grids, truth_output,
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
ax[1].title.set_text('True solution')
|
||||
cb = getattr(ax[2],
|
||||
method)(*grids,
|
||||
(truth_output - pred_output).cpu().detach(),
|
||||
(truth_output - pred_output),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
ax[2].title.set_text('Residual')
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(),
|
||||
cb = getattr(ax, method)(*grids, pred_output,
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
ax.title.set_text('Neural Network prediction')
|
||||
@@ -201,11 +210,13 @@ class Plotter:
|
||||
pts = pts.append(fixed_pts)
|
||||
pts = pts.to(device=solver.device)
|
||||
|
||||
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor)
|
||||
# computing soluting and sending to cpu
|
||||
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach()
|
||||
pts = pts.cpu()
|
||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts.extract(v), predicted_output, method, truth_solution,
|
||||
self._1d_plot(pts, predicted_output, v, method, truth_solution,
|
||||
**kwargs)
|
||||
elif len(v) == 2:
|
||||
self._2d_plot(pts, predicted_output, v, res, method, truth_solution,
|
||||
@@ -214,6 +225,7 @@ class Plotter:
|
||||
plt.tight_layout()
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
@@ -280,3 +292,4 @@ class Plotter:
|
||||
# saving in file
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
|
||||
Reference in New Issue
Block a user