""" Module for plotting. """ import matplotlib matplotlib.use('Qt5Agg') import matplotlib.pyplot as plt import numpy as np import torch from pina import LabelTensor from pina import PINN from .problem import SpatialProblem, TimeDependentProblem #from pina.tdproblem1d import TimeDepProblem1D class Plotter: def _plot_2D(self, obj, method='contourf'): """ """ if not isinstance(obj, PINN): raise RuntimeError res = 256 pts = obj.problem.spatial_domain.discretize(res, 'grid') grids_container = [ pts[:, 0].reshape(res, res), pts[:, 1].reshape(res, res), ] pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables) predicted_output = obj.model(pts.tensor) if hasattr(obj.problem, 'truth_solution'): truth_output = obj.problem.truth_solution(*pts.tensor.T).float() fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[0]) cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[1]) cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes) def _plot_1D_TimeDep(self, obj, method='contourf'): """ """ if not isinstance(obj, PINN): raise RuntimeError res = 256 grids_container = np.meshgrid( obj.problem.spatial_domain.discretize(res, 'grid'), obj.problem.temporal_domain.discretize(res, 'grid'), ) pts = np.hstack([ grids_container[0].reshape(-1, 1), grids_container[1].reshape(-1, 1), ]) pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables) predicted_output = obj.model(pts.tensor) if hasattr(obj.problem, 'truth_solution'): truth_output = obj.problem.truth_solution(*pts.tensor.T).float() fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[0]) cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[1]) cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes) def plot(self, obj, method='contourf', filename=None): """ """ res = 256 pts = obj.problem.domain.sample(res, 'grid') grids_container = [ pts[:, 0].reshape(res, res), pts[:, 1].reshape(res, res), ] predicted_output = obj.model(pts) if hasattr(obj.problem, 'truth_solution'): truth_output = obj.problem.truth_solution(*pts.tensor.T).float() fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[0]) cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) fig.colorbar(cb, ax=axes[1]) cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) fig.colorbar(cb, ax=axes) if filename: plt.savefig(filename) else: plt.show()