minor fix

This commit is contained in:
Your Name
2022-07-20 17:23:53 +02:00
committed by Nicola Demo
parent 75a81af99c
commit a05adea4e3
10 changed files with 231 additions and 203 deletions

View File

@@ -13,130 +13,113 @@ from .problem import SpatialProblem, TimeDependentProblem
class Plotter:
def _plot_2D(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
def plot_samples(self, pinn, variables=None):
res = 256
pts = obj.problem.spatial_domain.discretize(res, 'grid')
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if variables is None:
variables = pinn.problem.domain.variables
elif variables == 'spatial':
variables = pinn.problem.spatial_domain.variables
elif variables == 'temporal':
variables = pinn.problem.temporal_domain.variables
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
if len(variables) not in [1, 2, 3]:
raise ValueError
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
fig = plt.figure()
proj = '3d' if len(variables) == 3 else None
ax = fig.add_subplot(projection=proj)
for location in pinn.input_pts:
coords = pinn.input_pts[location].extract(variables).T.detach()
if coords.shape[0] == 1: # 1D samples
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
label=location)
else:
ax.plot(*coords, '.', label=location)
ax.set_xlabel(variables[0])
try:
ax.set_ylabel(variables[1])
except:
pass
def _plot_1D_TimeDep(self, obj, method='contourf'):
"""
"""
if not isinstance(obj, PINN):
raise RuntimeError
res = 256
grids_container = np.meshgrid(
obj.problem.spatial_domain.discretize(res, 'grid'),
obj.problem.temporal_domain.discretize(res, 'grid'),
)
pts = np.hstack([
grids_container[0].reshape(-1, 1),
grids_container[1].reshape(-1, 1),
])
pts = LabelTensor(torch.tensor(pts), obj.problem.input_variables)
predicted_output = obj.model(pts.tensor)
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def plot(self, obj, method='contourf', component='u', parametric=False, params_value=1.5, filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
if parametric:
pts_params = torch.ones(pts.shape[0], len(obj.problem.parameters), dtype=pts.dtype)*params_value
pts_params = LabelTensor(pts_params, obj.problem.parameters)
pts = pts.append(pts_params)
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
ind_dict = {}
all_locations = [condition for condition in obj.problem.conditions]
for location in all_locations:
if hasattr(obj.problem.conditions[location], 'location'):
keys_range_ = obj.problem.conditions[location].location.range_.keys()
if ('x' in keys_range_) and ('y' in keys_range_):
range_x = obj.problem.conditions[location].location.range_['x']
range_y = obj.problem.conditions[location].location.range_['y']
ind_x = np.where(np.logical_or(pts[:, 0]<range_x[0], pts[:, 0]>range_x[1]))
ind_y = np.where(np.logical_or(pts[:, 1]<range_y[0], pts[:, 1]>range_y[1]))
ind_to_exclude = np.union1d(ind_x, ind_y)
ind_dict[location] = ind_to_exclude
import functools
from functools import reduce
final_inds = reduce(np.intersect1d, ind_dict.values())
predicted_output = obj.model(pts)
predicted_output = predicted_output.extract([component])
predicted_output[final_inds] = np.nan
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach(), levels=32)
fig.colorbar(cb, ax=axes)
if filename:
plt.title('Output {} with parameter {}'.format(component, params_value))
plt.savefig(filename)
else:
plt.show()
def plot_samples(self, obj):
for location in obj.input_pts:
pts_x = obj.input_pts[location].extract(['x'])
pts_y = obj.input_pts[location].extract(['y'])
plt.plot(pts_x.detach(), pts_y.detach(), '.', label=location)
try:
ax.set_zlabel(variables[2])
except:
pass
plt.legend()
plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None):
"""
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach())
if truth_solution:
truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach())
plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0])
plt.show()
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
"""
"""
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
pred_output = pred.reshape(res, res)
if truth_solution:
truth_output = truth_solution(*pts.T).float().reshape(res, res)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.detach())
fig.colorbar(cb, ax=ax[0])
cb = getattr(ax[1], method)(*grids, truth_output.detach())
fig.colorbar(cb, ax=ax[1])
cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).detach())
fig.colorbar(cb, ax=ax[2])
else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.detach())
fig.colorbar(cb, ax=ax)
def plot(self, pinn, components, fixed_variables={}, method='contourf',
res=256, filename=None):
"""
"""
v = [
var for var in pinn.problem.input_variables
if var not in fixed_variables.keys()
]
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
for variable, value in fixed_variables.items():
new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable])
pts = pts.append(new)
predicted_output = pinn.model(pts)
if isinstance(components, str):
predicted_output = predicted_output.extract(components)
elif callable(components):
predicted_output = components(predicted_output)
truth_solution = getattr(pinn.problem, 'truth_solution', None)
if len(v) == 1:
self._1d_plot(pts, predicted_output, method, truth_solution)
elif len(v) == 2:
self._2d_plot(pts, predicted_output, v, res, method,
truth_solution)
if filename:
plt.title('Output {} with parameter {}'.format(components,
fixed_variables))
plt.savefig(filename)
else:
plt.show()