126 lines
4.0 KiB
Python
126 lines
4.0 KiB
Python
""" Module for plotting. """
|
|
import matplotlib
|
|
#matplotlib.use('Qt5Agg')
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
from pina import LabelTensor
|
|
from pina import PINN
|
|
from .problem import SpatialProblem, TimeDependentProblem
|
|
#from pina.tdproblem1d import TimeDepProblem1D
|
|
|
|
|
|
class Plotter:
|
|
|
|
def plot_samples(self, pinn, variables=None):
|
|
|
|
if variables is None:
|
|
variables = pinn.problem.domain.variables
|
|
elif variables == 'spatial':
|
|
variables = pinn.problem.spatial_domain.variables
|
|
elif variables == 'temporal':
|
|
variables = pinn.problem.temporal_domain.variables
|
|
|
|
if len(variables) not in [1, 2, 3]:
|
|
raise ValueError
|
|
|
|
fig = plt.figure()
|
|
proj = '3d' if len(variables) == 3 else None
|
|
ax = fig.add_subplot(projection=proj)
|
|
for location in pinn.input_pts:
|
|
coords = pinn.input_pts[location].extract(variables).T.detach()
|
|
if coords.shape[0] == 1: # 1D samples
|
|
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
|
|
label=location)
|
|
else:
|
|
ax.plot(*coords, '.', label=location)
|
|
|
|
ax.set_xlabel(variables[0])
|
|
try:
|
|
ax.set_ylabel(variables[1])
|
|
except:
|
|
pass
|
|
|
|
try:
|
|
ax.set_zlabel(variables[2])
|
|
except:
|
|
pass
|
|
|
|
plt.legend()
|
|
plt.show()
|
|
|
|
def _1d_plot(self, pts, pred, method, truth_solution=None):
|
|
"""
|
|
"""
|
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
|
|
|
ax.plot(pts, pred.detach())
|
|
|
|
if truth_solution:
|
|
truth_output = truth_solution(pts).float()
|
|
ax.plot(pts, truth_output.detach())
|
|
|
|
plt.xlabel(pts.labels[0])
|
|
plt.ylabel(pred.labels[0])
|
|
plt.show()
|
|
|
|
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
|
|
"""
|
|
"""
|
|
|
|
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
|
|
|
|
pred_output = pred.reshape(res, res)
|
|
if truth_solution:
|
|
truth_output = truth_solution(*pts.T).float().reshape(res, res)
|
|
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
|
|
|
cb = getattr(ax[0], method)(*grids, pred_output.detach())
|
|
fig.colorbar(cb, ax=ax[0])
|
|
cb = getattr(ax[1], method)(*grids, truth_output.detach())
|
|
fig.colorbar(cb, ax=ax[1])
|
|
cb = getattr(ax[2], method)(*grids,
|
|
(truth_output-pred_output).detach())
|
|
fig.colorbar(cb, ax=ax[2])
|
|
else:
|
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
|
cb = getattr(ax, method)(*grids, pred_output.detach())
|
|
fig.colorbar(cb, ax=ax)
|
|
|
|
|
|
def plot(self, pinn, components, fixed_variables={}, method='contourf',
|
|
res=256, filename=None):
|
|
"""
|
|
"""
|
|
|
|
v = [
|
|
var for var in pinn.problem.input_variables
|
|
if var not in fixed_variables.keys()
|
|
]
|
|
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
|
|
|
|
for variable, value in fixed_variables.items():
|
|
new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable])
|
|
pts = pts.append(new)
|
|
|
|
predicted_output = pinn.model(pts)
|
|
if isinstance(components, str):
|
|
predicted_output = predicted_output.extract(components)
|
|
elif callable(components):
|
|
predicted_output = components(predicted_output)
|
|
|
|
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
|
if len(v) == 1:
|
|
self._1d_plot(pts, predicted_output, method, truth_solution)
|
|
elif len(v) == 2:
|
|
self._2d_plot(pts, predicted_output, v, res, method,
|
|
truth_solution)
|
|
|
|
if filename:
|
|
plt.title('Output {} with parameter {}'.format(components,
|
|
fixed_variables))
|
|
plt.savefig(filename)
|
|
else:
|
|
plt.show()
|