From 927dbcf91e98dcf72ff9f4d6948f8e7835c3d55a Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 16 Jan 2024 11:38:39 +0100 Subject: [PATCH] plotter update --- pina/plotter.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/pina/plotter.py b/pina/plotter.py index ee6bfca..19823e5 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -11,14 +11,16 @@ class Plotter: Implementation of a plotter class, for easy visualizations. """ - def plot_samples(self, problem, variables=None, **kwargs): + def plot_samples(self, problem, variables=None, filename=None, **kwargs): """ Plot the training grid samples. - :param SolverInterface solver: The ``SolverInterface`` object. + :param AbstractProblem problem: The PINA problem from where to plot the domain. :param list(str) variables: Variables to plot. If None, all variables are plotted. If 'spatial', only spatial variables are plotted. If 'temporal', only temporal variables are plotted. Defaults to None. + :param str filename: The file name to save the plot. If None, the plot + is shown using the setted matplotlib frontend. Default is None. .. todo:: - Add support for 3D plots. @@ -65,15 +67,21 @@ class Plotter: pass plt.legend() - plt.show() + if filename: + plt.savefig(filename) + plt.close() + else: + plt.show() - def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs): + def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs): """Plot solution for one dimensional function :param pts: Points to plot the solution. :type pts: torch.Tensor :param pred: SolverInterface solution evaluated at 'pts'. :type pred: torch.Tensor + :param v: Fixed variables when plotting the solution. + :type v: torch.Tensor :param method: Not used, kept for code compatibility :type method: None :param truth_solution: Real solution evaluated at 'pts', @@ -82,18 +90,17 @@ class Plotter: """ fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8)) - ax.plot(pts, pred.detach(), label='Neural Network solution', **kwargs) + ax.plot(pts.extract(v), pred, label='Neural Network solution', **kwargs) if truth_solution: - truth_output = truth_solution(pts).float() - ax.plot(pts, truth_output.detach(), label='True solution', **kwargs) + truth_output = truth_solution(pts).detach() + ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs) # TODO: pred is a torch.Tensor, so no labels is available # extra variable for labels should be # passed in the function arguments. # plt.ylabel(pred.labels[0]) plt.legend() - plt.show() def _2d_plot(self, pts, @@ -109,6 +116,8 @@ class Plotter: :type pts: torch.Tensor :param pred: ``SolverInterface`` solution evaluated at 'pts'. :type pred: torch.Tensor + :param v: Fixed variables when plotting the solution. + :type v: torch.Tensor :param method: Matplotlib method to plot 2-dimensional data, see https://matplotlib.org/stable/api/axes_api.html for reference. @@ -118,30 +127,30 @@ class Plotter: :type truth_solution: torch.Tensor, optional """ - grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T] + grids = [p_.reshape(res, res) for p_ in pts.extract(v).T] pred_output = pred.reshape(res, res) if truth_solution: truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor) fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) - cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), + cb = getattr(ax[0], method)(*grids, pred_output, **kwargs) fig.colorbar(cb, ax=ax[0]) ax[0].title.set_text('Neural Network prediction') - cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), + cb = getattr(ax[1], method)(*grids, truth_output, **kwargs) fig.colorbar(cb, ax=ax[1]) ax[1].title.set_text('True solution') cb = getattr(ax[2], method)(*grids, - (truth_output - pred_output).cpu().detach(), + (truth_output - pred_output), **kwargs) fig.colorbar(cb, ax=ax[2]) ax[2].title.set_text('Residual') else: fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), + cb = getattr(ax, method)(*grids, pred_output, **kwargs) fig.colorbar(cb, ax=ax) ax.title.set_text('Neural Network prediction') @@ -201,11 +210,13 @@ class Plotter: pts = pts.append(fixed_pts) pts = pts.to(device=solver.device) - predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor) + # computing soluting and sending to cpu + predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach() + pts = pts.cpu() truth_solution = getattr(solver.problem, 'truth_solution', None) if len(v) == 1: - self._1d_plot(pts.extract(v), predicted_output, method, truth_solution, + self._1d_plot(pts, predicted_output, v, method, truth_solution, **kwargs) elif len(v) == 2: self._2d_plot(pts, predicted_output, v, res, method, truth_solution, @@ -214,6 +225,7 @@ class Plotter: plt.tight_layout() if filename: plt.savefig(filename) + plt.close() else: plt.show() @@ -280,3 +292,4 @@ class Plotter: # saving in file if filename: plt.savefig(filename) + plt.close()