update plotter

This commit is contained in:
Dario Coscia
2023-11-09 18:20:51 +01:00
committed by Nicola Demo
parent 934ae409ff
commit 0d38de5afe
21 changed files with 171 additions and 165 deletions

View File

@@ -88,7 +88,6 @@ class Plotter:
truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach(), label='True solution', **kwargs)
plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0])
plt.legend()
plt.show()
@@ -120,7 +119,7 @@ class Plotter:
pred_output = pred.reshape(res, res)
if truth_solution:
truth_output = truth_solution(pts).float().reshape(res, res)
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(),
@@ -157,8 +156,7 @@ class Plotter:
:param SolverInterface solver: The ``SolverInterface`` object instance.
:param list(str) components: The output variable to plot. If None, all
the output variables of the problem are selected. Default value is
None.
the output variables of the problem are selected. Default value is None.
:param dict fixed_variables: A dictionary with all the variables that
should be kept fixed during the plot. The keys of the dictionary
are the variables name whereas the values are the corresponding
@@ -173,7 +171,11 @@ class Plotter:
"""
if components is None:
components = [solver.problem.output_variables]
components = solver.problem.output_variables
if len(components) > 1:
raise NotImplementedError('Multidimensional plots are not implemented, '
'set components to an available components of the problem.')
v = [
var for var in solver.problem.input_variables
if var not in fixed_variables.keys()
@@ -188,13 +190,9 @@ class Plotter:
pts = pts.append(fixed_pts)
pts = pts.to(device=solver.device)
predicted_output = solver.forward(pts)
if isinstance(components, str):
predicted_output = predicted_output.extract(components)
elif callable(components):
predicted_output = components(predicted_output)
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor)
truth_solution = getattr(solver.problem, 'truth_solution', None)
if len(v) == 1:
self._1d_plot(pts, predicted_output, method, truth_solution,
**kwargs)