modify 3d plot/ adding tests

This commit is contained in:
Dario Coscia
2024-01-22 18:20:48 +01:00
committed by Nicola Demo
parent 927dbcf91e
commit 513144dfaf
2 changed files with 76 additions and 5 deletions

View File

@@ -39,31 +39,33 @@ class Plotter:
variables = problem.temporal_domain.variables
if len(variables) not in [1, 2, 3]:
raise ValueError
raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.')
fig = plt.figure()
proj = '3d' if len(variables) == 3 else None
ax = fig.add_subplot(projection=proj)
for location in problem.input_pts:
coords = problem.input_pts[location].extract(variables).T.detach()
if coords.shape[0] == 1: # 1D samples
if len(variables)==1: # 1D samples
ax.plot(coords.flatten(),
torch.zeros(coords.flatten().shape),
'.',
label=location,
**kwargs)
else:
elif len(variables)==2:
ax.plot(*coords, '.', label=location, **kwargs)
elif len(variables)==3:
ax.scatter(*coords, '.', label=location, **kwargs)
ax.set_xlabel(variables[0])
try:
ax.set_ylabel(variables[1])
except:
except (IndexError, AttributeError):
pass
try:
ax.set_zlabel(variables[2])
except:
except (IndexError, AttributeError):
pass
plt.legend()