update plotter
This commit is contained in:
committed by
Nicola Demo
parent
934ae409ff
commit
0d38de5afe
@@ -88,7 +88,6 @@ class Plotter:
|
||||
truth_output = truth_solution(pts).float()
|
||||
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
|
||||
|
||||
plt.xlabel(pts.labels[0])
|
||||
plt.ylabel(pred.labels[0])
|
||||
plt.legend()
|
||||
plt.show()
|
||||
@@ -120,7 +119,7 @@ class Plotter:
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(),
|
||||
@@ -157,8 +156,7 @@ class Plotter:
|
||||
|
||||
:param SolverInterface solver: The ``SolverInterface`` object instance.
|
||||
:param list(str) components: The output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
the output variables of the problem are selected. Default value is None.
|
||||
:param dict fixed_variables: A dictionary with all the variables that
|
||||
should be kept fixed during the plot. The keys of the dictionary
|
||||
are the variables name whereas the values are the corresponding
|
||||
@@ -173,7 +171,11 @@ class Plotter:
|
||||
"""
|
||||
|
||||
if components is None:
|
||||
components = [solver.problem.output_variables]
|
||||
components = solver.problem.output_variables
|
||||
|
||||
if len(components) > 1:
|
||||
raise NotImplementedError('Multidimensional plots are not implemented, '
|
||||
'set components to an available components of the problem.')
|
||||
v = [
|
||||
var for var in solver.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
@@ -188,13 +190,9 @@ class Plotter:
|
||||
pts = pts.append(fixed_pts)
|
||||
pts = pts.to(device=solver.device)
|
||||
|
||||
predicted_output = solver.forward(pts)
|
||||
if isinstance(components, str):
|
||||
predicted_output = predicted_output.extract(components)
|
||||
elif callable(components):
|
||||
predicted_output = components(predicted_output)
|
||||
|
||||
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor)
|
||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user