Tutorials and Doc (#191)

* Tutorial doc update
* update doc tutorial
* doc not compiling

---------

Co-authored-by: Dario Coscia <dcoscia@euclide.maths.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Nicola Demo
2023-10-23 12:48:09 +02:00
parent ac829aece9
commit 0c8072274e
93 changed files with 2306 additions and 1685 deletions

View File

@@ -11,7 +11,7 @@ class Plotter:
Implementation of a plotter class, for easy visualizations.
"""
def plot_samples(self, solver, variables=None):
def plot_samples(self, problem, variables=None):
"""
Plot the training grid samples.
@@ -30,11 +30,11 @@ class Plotter:
"""
if variables is None:
variables = solver.problem.domain.variables
variables = problem.domain.variables
elif variables == 'spatial':
variables = solver.problem.spatial_domain.variables
variables = problem.spatial_domain.variables
elif variables == 'temporal':
variables = solver.problem.temporal_domain.variables
variables = problem.temporal_domain.variables
if len(variables) not in [1, 2, 3]:
raise ValueError
@@ -42,11 +42,11 @@ class Plotter:
fig = plt.figure()
proj = '3d' if len(variables) == 3 else None
ax = fig.add_subplot(projection=proj)
for location in solver.problem.input_pts:
coords = solver.problem.input_pts[location].extract(
for location in problem.input_pts:
coords = problem.input_pts[location].extract(
variables).T.detach()
if coords.shape[0] == 1: # 1D samples
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
ax.plot(coords.flatten(), torch.zeros(coords.flatten().shape), '.',
label=location)
else:
ax.plot(*coords, '.', label=location)
@@ -80,14 +80,15 @@ class Plotter:
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach(), **kwargs)
ax.plot(pts, pred.detach(), label='neural net solution', **kwargs)
if truth_solution:
truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach(), **kwargs)
ax.plot(pts, truth_output.detach(), label='true solution', **kwargs)
plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0])
plt.legend()
plt.show()
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,