preliminary modifications for N-S

This commit is contained in:
Anna Ivagnes
2022-05-05 17:12:31 +02:00
parent d152fe67e3
commit 8130912926
13 changed files with 213 additions and 162 deletions

View File

@@ -1,6 +1,6 @@
""" Module for plotting. """
import matplotlib
matplotlib.use('Qt5Agg')
#matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -32,15 +32,15 @@ class Plotter:
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
@@ -66,66 +66,50 @@ class Plotter:
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def plot(self, obj, method='contourf', filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['p']
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
if filename:
plt.savefig(filename)
else:
plt.show()
def plot(self, obj, method='contourf', filename=None):
def plot(self, obj, method='contourf', component='u', parametric=False, params_value=1, filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
if parametric:
pts_params = torch.ones(pts.shape[0], len(obj.problem.parameters), dtype=pts.dtype)*params_value
pts_params = LabelTensor(pts_params, obj.problem.parameters)
pts = pts.append(pts_params)
grids_container = [
pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res),
]
ind_dict = {}
all_locations = [condition for condition in obj.problem.conditions]
for location in all_locations:
if hasattr(obj.problem.conditions[location], 'location'):
keys_range_ = obj.problem.conditions[location].location.range_.keys()
if ('x' in keys_range_) and ('y' in keys_range_):
range_x = obj.problem.conditions[location].location.range_['x']
range_y = obj.problem.conditions[location].location.range_['y']
ind_x = np.where(np.logical_or(pts[:, 0]<range_x[0], pts[:, 0]>range_x[1]))
ind_y = np.where(np.logical_or(pts[:, 1]<range_y[0], pts[:, 1]>range_y[1]))
ind_to_exclude = np.union1d(ind_x, ind_y)
ind_dict[location] = ind_to_exclude
import functools
from functools import reduce
final_inds = reduce(np.intersect1d, ind_dict.values())
predicted_output = obj.model(pts)
predicted_output = predicted_output.extract(['u'])
predicted_output = predicted_output.extract([component])
predicted_output[final_inds] = np.nan
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
@@ -142,16 +126,16 @@ class Plotter:
fig.colorbar(cb, ax=axes)
if filename:
plt.title('Output {} with parameter {}'.format(component, params_value))
plt.savefig(filename)
else:
plt.show()
def plot_samples(self, obj):
for location in obj.input_pts:
plt.plot(*obj.input_pts[location].T.detach(), '.', label=location)
pts_x = obj.input_pts[location].extract(['x'])
pts_y = obj.input_pts[location].extract(['y'])
plt.plot(pts_x.detach(), pts_y.detach(), '.', label=location)
plt.legend()
plt.show()