fix examples (#21)
This commit is contained in:
@@ -50,22 +50,23 @@ class Plotter:
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None):
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
|
||||
"""
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
ax.plot(pts, pred.detach())
|
||||
ax.plot(pts, pred.detach(), **kwargs)
|
||||
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float()
|
||||
ax.plot(pts, truth_output.detach())
|
||||
ax.plot(pts, truth_output.detach(), **kwargs)
|
||||
|
||||
plt.xlabel(pts.labels[0])
|
||||
plt.ylabel(pred.labels[0])
|
||||
plt.show()
|
||||
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||
**kwargs):
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -73,27 +74,29 @@ class Plotter:
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(*pts.T).float().reshape(res, res)
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.detach())
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.detach())
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
cb = getattr(ax[2], method)(*grids,
|
||||
(truth_output-pred_output).detach())
|
||||
(truth_output-pred_output).detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output.detach())
|
||||
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
|
||||
def plot(self, pinn, components, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None):
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
"""
|
||||
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
v = [
|
||||
var for var in pinn.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
@@ -112,10 +115,11 @@ class Plotter:
|
||||
|
||||
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution)
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
elif len(v) == 2:
|
||||
self._2d_plot(pts, predicted_output, v, res, method,
|
||||
truth_solution)
|
||||
truth_solution, **kwargs)
|
||||
|
||||
if filename:
|
||||
plt.title('Output {} with parameter {}'.format(components,
|
||||
|
||||
Reference in New Issue
Block a user