diff --git a/pina/plotter.py b/pina/plotter.py index 4c2ef11..d00b44d 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -15,7 +15,8 @@ class Plotter: """ Plot the training grid samples. - :param AbstractProblem problem: The PINA problem from where to plot the domain. + :param AbstractProblem problem: The PINA problem from where to plot + the domain. :param list(str) variables: Variables to plot. If None, all variables are plotted. If 'spatial', only spatial variables are plotted. If 'temporal', only temporal variables are plotted. Defaults to None. @@ -39,7 +40,8 @@ class Plotter: variables = problem.temporal_domain.variables if len(variables) not in [1, 2, 3]: - raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.') + raise ValueError('Samples can be plotted only in ' + 'dimensions 1, 2 and 3.') fig = plt.figure() proj = '3d' if len(variables) == 3 else None @@ -96,7 +98,8 @@ class Plotter: if truth_solution: truth_output = truth_solution(pts).detach() - ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs) + ax.plot(pts.extract(v), truth_output, + label='True solution', **kwargs) # TODO: pred is a torch.Tensor, so no labels is available # extra variable for labels should be @@ -197,7 +200,8 @@ class Plotter: if len(components) > 1: raise NotImplementedError('Multidimensional plots are not implemented, ' - 'set components to an available components of the problem.') + 'set components to an available components of' + ' the problem.') v = [ var for var in solver.problem.input_variables if var not in fixed_variables.keys() @@ -213,7 +217,8 @@ class Plotter: pts = pts.to(device=solver.device) # computing soluting and sending to cpu - predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach() + predicted_output = solver.forward(pts).extract(components) + predicted_output = predicted_output.as_subclass(torch.Tensor).cpu().detach() pts = pts.cpu() truth_solution = getattr(solver.problem, 'truth_solution', None)