Files
PINA/pina/plotter.py
2022-03-07 10:09:40 +01:00

160 lines
6.3 KiB
Python

""" Module for plotting. """
import matplotlib
#matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from pina import LabelTensor
from pina import PINN
from .problem import SpatialProblem, TimeDependentProblem
#from pina.tdproblem1d import TimeDepProblem1D
class Plotter:
def _plot_2D(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
res = 256
pts = obj.problem.spatial_domain.discretize(res, 'grid')
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def _plot_1D_TimeDep(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
res = 256
grids_container = np.meshgrid(
obj.problem.spatial_domain.discretize(res, 'grid'),
obj.problem.temporal_domain.discretize(res, 'grid'),
)
pts = np.hstack([
grids_container[0].reshape(-1, 1),
grids_container[1].reshape(-1, 1),
])
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def plot(self, obj, method='contourf', filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['p']
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
if filename:
plt.savefig(filename)
else:
plt.show()
def plot(self, obj, method='contourf', filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['ux']
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
if filename:
plt.savefig(filename)
else:
plt.show()
def plot_samples(self, obj):
for location in obj.input_pts:
plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location)
plt.legend()
plt.show()