Tutorials and Doc (#191)
* Tutorial doc update * update doc tutorial * doc not compiling --------- Co-authored-by: Dario Coscia <dcoscia@euclide.maths.sissa.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
@@ -11,7 +11,7 @@ class Plotter:
|
||||
Implementation of a plotter class, for easy visualizations.
|
||||
"""
|
||||
|
||||
def plot_samples(self, solver, variables=None):
|
||||
def plot_samples(self, problem, variables=None):
|
||||
"""
|
||||
Plot the training grid samples.
|
||||
|
||||
@@ -30,11 +30,11 @@ class Plotter:
|
||||
"""
|
||||
|
||||
if variables is None:
|
||||
variables = solver.problem.domain.variables
|
||||
variables = problem.domain.variables
|
||||
elif variables == 'spatial':
|
||||
variables = solver.problem.spatial_domain.variables
|
||||
variables = problem.spatial_domain.variables
|
||||
elif variables == 'temporal':
|
||||
variables = solver.problem.temporal_domain.variables
|
||||
variables = problem.temporal_domain.variables
|
||||
|
||||
if len(variables) not in [1, 2, 3]:
|
||||
raise ValueError
|
||||
@@ -42,11 +42,11 @@ class Plotter:
|
||||
fig = plt.figure()
|
||||
proj = '3d' if len(variables) == 3 else None
|
||||
ax = fig.add_subplot(projection=proj)
|
||||
for location in solver.problem.input_pts:
|
||||
coords = solver.problem.input_pts[location].extract(
|
||||
for location in problem.input_pts:
|
||||
coords = problem.input_pts[location].extract(
|
||||
variables).T.detach()
|
||||
if coords.shape[0] == 1: # 1D samples
|
||||
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
|
||||
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
|
||||
label=location)
|
||||
else:
|
||||
ax.plot(*coords, '.', label=location)
|
||||
@@ -80,14 +80,15 @@ class Plotter:
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
ax.plot(pts, pred.detach(), **kwargs)
|
||||
ax.plot(pts, pred.detach(), label='neural net solution', **kwargs)
|
||||
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float()
|
||||
ax.plot(pts, truth_output.detach(), **kwargs)
|
||||
ax.plot(pts, truth_output.detach(), label='true solution', **kwargs)
|
||||
|
||||
plt.xlabel(pts.labels[0])
|
||||
plt.ylabel(pred.labels[0])
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||
|
||||
Reference in New Issue
Block a user