diff --git a/pina/plotter.py b/pina/plotter.py index 19823e5..4c2ef11 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -39,31 +39,33 @@ class Plotter: variables = problem.temporal_domain.variables if len(variables) not in [1, 2, 3]: - raise ValueError + raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.') fig = plt.figure() proj = '3d' if len(variables) == 3 else None ax = fig.add_subplot(projection=proj) for location in problem.input_pts: coords = problem.input_pts[location].extract(variables).T.detach() - if coords.shape[0] == 1: # 1D samples + if len(variables)==1: # 1D samples ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.', label=location, **kwargs) - else: + elif len(variables)==2: ax.plot(*coords, '.', label=location, **kwargs) + elif len(variables)==3: + ax.scatter(*coords, '.', label=location, **kwargs) ax.set_xlabel(variables[0]) try: ax.set_ylabel(variables[1]) - except: + except (IndexError, AttributeError): pass try: ax.set_zlabel(variables[2]) - except: + except (IndexError, AttributeError): pass plt.legend() diff --git a/tests/test_plotter.py b/tests/test_plotter.py new file mode 100644 index 0000000..99f99bc --- /dev/null +++ b/tests/test_plotter.py @@ -0,0 +1,69 @@ +from pina.geometry import CartesianDomain +from pina import Condition, Plotter +from matplotlib.testing.decorators import image_comparison +import matplotlib.pyplot as plt +from pina.problem import SpatialProblem +from pina.equation import FixedValue + + +class FooProblem1D(SpatialProblem): + + # assign output/ spatial and temporal variables + output_variables = ['u'] + spatial_domain = CartesianDomain({'x' : [-1, 1]}) + + # problem condition statement + conditions = { + 'D': Condition(location=CartesianDomain({'x': [-1, 1]}), equation=FixedValue(0.)), + } + +class FooProblem2D(SpatialProblem): + + # assign output/ spatial and temporal variables + output_variables = ['u'] + spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]}) + + # problem condition statement + conditions = { + 'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]}), equation=FixedValue(0.)), + } + +class FooProblem3D(SpatialProblem): + + # assign output/ spatial and temporal variables + output_variables = ['u'] + spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]}) + + # problem condition statement + conditions = { + 'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]}), equation=FixedValue(0.)), + } + + + +def test_constructor(): + Plotter() + +def test_plot_samples_1d(): + problem = FooProblem1D() + problem.discretise_domain(n=10, mode='grid', variables = 'x', locations=['D']) + pl = Plotter() + pl.plot_samples(problem=problem, filename='fig.png') + import os + os.remove('fig.png') + +def test_plot_samples_2d(): + problem = FooProblem2D() + problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y'], locations=['D']) + pl = Plotter() + pl.plot_samples(problem=problem, filename='fig.png') + import os + os.remove('fig.png') + +def test_plot_samples_3d(): + problem = FooProblem3D() + problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y', 'z'], locations=['D']) + pl = Plotter() + pl.plot_samples(problem=problem, filename='fig.png') + import os + os.remove('fig.png') \ No newline at end of file