fix examples (#21)

This commit is contained in:
Nicola Demo
2022-07-21 13:41:59 +02:00
committed by GitHub
parent 62f203fcc3
commit e8c2f87460
14 changed files with 63 additions and 67 deletions

View File

@@ -50,22 +50,23 @@ class Plotter:
plt.legend()
plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None):
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
"""
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach())
ax.plot(pts, pred.detach(), **kwargs)
if truth_solution:
truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach())
ax.plot(pts, truth_output.detach(), **kwargs)
plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0])
plt.show()
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
**kwargs):
"""
"""
@@ -73,27 +74,29 @@ class Plotter:
pred_output = pred.reshape(res, res)
if truth_solution:
truth_output = truth_solution(*pts.T).float().reshape(res, res)
truth_output = truth_solution(pts).float().reshape(res, res)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.detach())
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax[0])
cb = getattr(ax[1], method)(*grids, truth_output.detach())
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax[1])
cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).detach())
(truth_output-pred_output).detach(),
**kwargs)
fig.colorbar(cb, ax=ax[2])
else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.detach())
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax)
def plot(self, pinn, components, fixed_variables={}, method='contourf',
res=256, filename=None):
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
res=256, filename=None, **kwargs):
"""
"""
if components is None:
components = [pinn.problem.output_variables]
v = [
var for var in pinn.problem.input_variables
if var not in fixed_variables.keys()
@@ -112,10 +115,11 @@ class Plotter:
truth_solution = getattr(pinn.problem, 'truth_solution', None)
if len(v) == 1:
self._1d_plot(pts, predicted_output, method, truth_solution)
self._1d_plot(pts, predicted_output, method, truth_solution,
**kwargs)
elif len(v) == 2:
self._2d_plot(pts, predicted_output, v, res, method,
truth_solution)
truth_solution, **kwargs)
if filename:
plt.title('Output {} with parameter {}'.format(components,