Refactoring code

This commit is contained in:
Your Name
2022-01-27 14:55:42 +01:00
parent fb16fc7f3a
commit fa8ffd5042
32 changed files with 417 additions and 442 deletions

93
pina/plotter.py Normal file
View File

@@ -0,0 +1,93 @@
""" Module for plotting. """
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from pina import LabelTensor
from pina import PINN
from .problem import Problem2D, Problem1D, TimeDependentProblem
#from pina.tdproblem1d import TimeDepProblem1D
class Plotter:
def _plot_2D(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
res = 256
pts = obj.problem.spatial_domain.discretize(res, 'grid')
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def _plot_1D_TimeDep(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
res = 256
grids_container = np.meshgrid(
obj.problem.spatial_domain.discretize(res, 'grid'),
obj.problem.temporal_domain.discretize(res, 'grid'),
)
pts = np.hstack([
grids_container[0].reshape(-1, 1),
grids_container[1].reshape(-1, 1),
])
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def plot(self, obj, filename=None):
"""
"""
if isinstance(obj.problem, (TimeDependentProblem, Problem1D)): # time-dep 1D
self._plot_1D_TimeDep(obj, method='pcolor')
elif isinstance(obj.problem, Problem2D): # 2D
self._plot_2D(obj, method='pcolor')
if filename:
plt.savefig(filename)
else:
plt.show()