use LabelTensor, fix minor, docs

This commit is contained in:
Your Name
2022-03-29 18:05:26 +02:00
parent 12f4084d7f
commit 6b001c6c53
19 changed files with 370 additions and 322 deletions

View File

@@ -1,6 +1,6 @@
""" Module for plotting. """
import matplotlib
#matplotlib.use('Qt5Agg')
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -119,16 +119,15 @@ class Plotter:
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['u']
predicted_output = predicted_output.extract(['u'])
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
truth_output = obj.problem.truth_solution(*pts.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
@@ -139,7 +138,6 @@ class Plotter:
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
@@ -153,7 +151,7 @@ class Plotter:
def plot_samples(self, obj):
for location in obj.input_pts:
plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location)
plt.plot(*obj.input_pts[location].T.detach(), '.', label=location)
plt.legend()
plt.show()