diff --git a/pina/plotter.py b/pina/plotter.py index adeb895..ee6bfca 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -158,8 +158,9 @@ class Plotter: Plot sample of SolverInterface output. :param SolverInterface solver: The ``SolverInterface`` object instance. - :param list(str) components: The output variable to plot. If None, all - the output variables of the problem are selected. Default value is None. + :param str | list(str) components: The output variable(s) to plot. + If None, all the output variables of the problem are selected. + Default value is None. :param dict fixed_variables: A dictionary with all the variables that should be kept fixed during the plot. The keys of the dictionary are the variables name whereas the values are the corresponding @@ -176,6 +177,13 @@ class Plotter: if components is None: components = solver.problem.output_variables + if isinstance(components, str): + components = [components] + + if not isinstance(components, list): + raise NotImplementedError('Output variables must be passed' + 'as a string or a list of strings.') + if len(components) > 1: raise NotImplementedError('Multidimensional plots are not implemented, ' 'set components to an available components of the problem.')