version 0.0.1

This commit is contained in:
Your Name
2022-02-11 16:44:37 +01:00
parent fa8ffd5042
commit 1483746b45
29 changed files with 416 additions and 559 deletions

View File

@@ -7,7 +7,7 @@ import torch
from pina import LabelTensor
from pina import PINN
from .problem import Problem2D, Problem1D, TimeDependentProblem
from .problem import SpatialProblem, TimeDependentProblem
#from pina.tdproblem1d import TimeDepProblem1D
@@ -79,13 +79,31 @@ class Plotter:
def plot(self, obj, filename=None):
def plot(self, obj, method='contourf', filename=None):
"""
"""
if isinstance(obj.problem, (TimeDependentProblem, Problem1D)): # time-dep 1D
self._plot_1D_TimeDep(obj, method='pcolor')
elif isinstance(obj.problem, Problem2D): # 2D
self._plot_2D(obj, method='pcolor')
res = 256
pts = obj.problem.domain.sample(res, 'grid')
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
if filename:
plt.savefig(filename)