preliminary modifications for N-S

This commit is contained in:
Anna Ivagnes
2022-05-05 17:12:31 +02:00
parent d152fe67e3
commit 8130912926
13 changed files with 213 additions and 162 deletions

View File

@@ -14,22 +14,22 @@ class Burgers1D(TimeDependentProblem, SpatialProblem):
domain = Span({'x': [-1, 1], 't': [0, 1]}) domain = Span({'x': [-1, 1], 't': [0, 1]})
def burger_equation(input_, output_): def burger_equation(input_, output_):
grad_u = grad(output_['u'], input_) grad_u = grad(output_.extract(['u']), input_)
grad_x = grad_u['x'] grad_x = grad_u.extract(['x'])
grad_t = grad_u['t'] grad_t = grad_u.extract(['t'])
gradgrad_u_x = grad(grad_u['x'], input_) gradgrad_u_x = grad(grad_u.extract(['x']), input_)
return ( return (
grad_u['t'] + output_['u']*grad_u['x'] - grad_u.extract(['t']) + output_.extract(['u'])*grad_u.extract(['x']) -
(0.01/torch.pi)*gradgrad_u_x['x'] (0.01/torch.pi)*gradgrad_u_x.extract(['x'])
) )
def nil_dirichlet(input_, output_): def nil_dirichlet(input_, output_):
u_expected = 0.0 u_expected = 0.0
return output_['u'] - u_expected return output_.extract(['u']) - u_expected
def initial_condition(input_, output_): def initial_condition(input_, output_):
u_expected = -torch.sin(torch.pi*input_['x']) u_expected = -torch.sin(torch.pi*input_.extract(['x']))
return output_['u'] - u_expected return output_.extract(['u']) - u_expected
conditions = { conditions = {
'gamma1': Condition(Span({'x': -1, 't': [0, 1]}), nil_dirichlet), 'gamma1': Condition(Span({'x': -1, 't': [0, 1]}), nil_dirichlet),

View File

@@ -12,26 +12,26 @@ class EllipticOptimalControl(Problem2D):
def __init__(self, alpha=1): def __init__(self, alpha=1):
def term1(input_, output_): def term1(input_, output_):
grad_p = self.grad(output_['p'], input_) grad_p = self.grad(output_.extract(['p']), input_)
gradgrad_p_x1 = self.grad(grad_p['x1'], input_) gradgrad_p_x1 = self.grad(grad_p.extract(['x1']), input_)
gradgrad_p_x2 = self.grad(grad_p['x2'], input_) gradgrad_p_x2 = self.grad(grad_p.extract(['x2']), input_)
yd = 2.0 yd = 2.0
return output_['y'] - yd - (gradgrad_p_x1['x1'] + gradgrad_p_x2['x2']) return output_.extract(['y']) - yd - (gradgrad_p_x1.extract(['x1']) + gradgrad_p_x2.extract(['x2']))
def term2(input_, output_): def term2(input_, output_):
grad_y = self.grad(output_['y'], input_) grad_y = self.grad(output_.extract(['y']), input_)
gradgrad_y_x1 = self.grad(grad_y['x1'], input_) gradgrad_y_x1 = self.grad(grad_y.extract(['x1']), input_)
gradgrad_y_x2 = self.grad(grad_y['x2'], input_) gradgrad_y_x2 = self.grad(grad_y.extract(['x2']), input_)
return - (gradgrad_y_x1['x1'] + gradgrad_y_x2['x2']) - output_['u'] return - (gradgrad_y_x1.extract(['x1']) + gradgrad_y_x2.extract(['x2'])) - output_.extract(['u'])
def term3(input_, output_): def term3(input_, output_):
return output_['p'] - output_['u']*alpha return output_.extract(['p']) - output_.extract(['u'])*alpha
def nil_dirichlet(input_, output_): def nil_dirichlet(input_, output_):
y_value = 0.0 y_value = 0.0
p_value = 0.0 p_value = 0.0
return torch.abs(output_['y'] - y_value) + torch.abs(output_['p'] - p_value) return torch.abs(output_.extract(['y']) - y_value) + torch.abs(output_.extract(['p']) - p_value)
self.conditions = { self.conditions = {
'gamma1': {'location': Segment((xmin, ymin), (xmax, ymin)), 'func': nil_dirichlet}, 'gamma1': {'location': Segment((xmin, ymin), (xmax, ymin)), 'func': nil_dirichlet},

View File

@@ -14,13 +14,13 @@ class ParametricPoisson(SpatialProblem, ParametricProblem):
def laplace_equation(input_, output_): def laplace_equation(input_, output_):
force_term = torch.exp( force_term = torch.exp(
- 2*(input_['x'] - input_['mu1'])**2 - 2*(input_['y'] - - 2*(input_.extract(['x']) - input_.extract(['mu1']))**2 - 2*(input_.extract(['y']) -
input_['mu2'])**2) input_.extract(['mu2']))**2)
return nabla(output_['u'], input_) - force_term return nabla(output_.extract(['u']), input_) - force_term
def nil_dirichlet(input_, output_): def nil_dirichlet(input_, output_):
value = 0.0 value = 0.0
return output_['u'] - value return output_.extract(['u']) - value
conditions = { conditions = {
'gamma1': Condition( 'gamma1': Condition(

View File

@@ -13,28 +13,24 @@ class Stokes(SpatialProblem):
domain = Span({'x': [-2, 2], 'y': [-1, 1]}) domain = Span({'x': [-2, 2], 'y': [-1, 1]})
def momentum(input_, output_): def momentum(input_, output_):
#print(nabla(output_['ux', 'uy'], input_)) nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
#print(grad(output_['p'], input_)) LabelTensor(nabla(output_.extract(['uy']), input_), ['y'])))
nabla_ = LabelTensor.hstack([ return - nabla_ + grad(output_.extract(['p']), input_)
LabelTensor(nabla(output_['ux'], input_), ['x']),
LabelTensor(nabla(output_['uy'], input_), ['y'])])
#return LabelTensor(nabla_.tensor + grad(output_['p'], input_).tensor, ['x', 'y'])
return nabla_.tensor + grad(output_['p'], input_).tensor
def continuity(input_, output_): def continuity(input_, output_):
return div(output_['ux', 'uy'], input_) return div(output_.extract(['ux', 'uy']), input_)
def inlet(input_, output_): def inlet(input_, output_):
value = 2.0 value = 2 * (1 - input_.extract(['y'])**2)
return output_['ux'] - value return output_.extract(['ux']) - value
def outlet(input_, output_): def outlet(input_, output_):
value = 0.0 value = 0.0
return output_['p'] - value return output_.extract(['p']) - value
def wall(input_, output_): def wall(input_, output_):
value = 0.0 value = 0.0
return output_['ux', 'uy'].tensor - value return output_.extract(['ux', 'uy']) - value
conditions = { conditions = {
'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall), 'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall),

View File

@@ -1,7 +1,7 @@
import argparse import argparse
import torch import torch
from torch.nn import Softplus from torch.nn import Softplus
from pina import Plotter
from pina import PINN as pPINN from pina import PINN as pPINN
from problems.parametric_poisson import ParametricPoisson from problems.parametric_poisson import ParametricPoisson
from pina.model import FeedForward from pina.model import FeedForward
@@ -14,7 +14,7 @@ class myFeature(torch.nn.Module):
super(myFeature, self).__init__() super(myFeature, self).__init__()
def forward(self, x): def forward(self, x):
return torch.exp(- 2*(x['x'] - x['mu1'])**2 - 2*(x['y'] - x['mu2'])**2) return torch.exp(- 2*(x.extract(['x']) - x.extract(['mu1']))**2 - 2*(x.extract(['y']) - x.extract(['mu2']))**2)
if __name__ == "__main__": if __name__ == "__main__":
@@ -31,7 +31,7 @@ if __name__ == "__main__":
poisson_problem = ParametricPoisson() poisson_problem = ParametricPoisson()
model = FeedForward( model = FeedForward(
layers=[200, 40, 10], layers=[10, 10, 10],
output_variables=poisson_problem.output_variables, output_variables=poisson_problem.output_variables,
input_variables=poisson_problem.input_variables, input_variables=poisson_problem.input_variables,
func=Softplus, func=Softplus,
@@ -42,15 +42,20 @@ if __name__ == "__main__":
poisson_problem, poisson_problem,
model, model,
lr=0.0006, lr=0.0006,
regularizer=1e-6, regularizer=1e-6)
lr_accelerate=None)
if args.s: if args.s:
pinn.span_pts(2000, 'random', ['D']) pinn.span_pts(500, n_params=10, mode_spatial='random', locations=['D'])
pinn.span_pts(200, 'random', ['gamma1', 'gamma2', 'gamma3', 'gamma4']) pinn.span_pts(200, n_params=10, mode_spatial='random', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.train(10000, 10) pinn.plot_pts()
pinn.train(10000, 100)
with open('param_poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
for i, losses in enumerate(pinn.history):
file_.write('{} {}\n'.format(i, sum(losses)))
pinn.save_state('pina.poisson_param') pinn.save_state('pina.poisson_param')
else: else:
pinn.load_state('pina.poisson_param') pinn.load_state('pina.poisson_param')
plotter = Plotter()
plotter.plot(pinn, component='u', parametric=True, params_value=0)

View File

@@ -52,9 +52,9 @@ if __name__ == "__main__":
if args.s: if args.s:
print(pinn) print(pinn)
pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4']) pinn.span_pts(20, mode_spatial='grid', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.span_pts(20, 'grid', ['D']) pinn.span_pts(20, mode_spatial='grid', locations=['D'])
#pinn.plot_pts() pinn.plot_pts()
pinn.train(5000, 100) pinn.train(5000, 100)
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_: with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
for i, losses in enumerate(pinn.history): for i, losses in enumerate(pinn.history):
@@ -64,6 +64,6 @@ if __name__ == "__main__":
else: else:
pinn.load_state('pina.poisson') pinn.load_state('pina.poisson')
plotter = Plotter() plotter = Plotter()
plotter.plot(pinn) plotter.plot(pinn, component='u')

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
stokes_problem = Stokes() stokes_problem = Stokes()
model = FeedForward( model = FeedForward(
layers=[40, 20, 20, 10], layers=[10, 10, 10, 10],
output_variables=stokes_problem.output_variables, output_variables=stokes_problem.output_variables,
input_variables=stokes_problem.input_variables, input_variables=stokes_problem.input_variables,
func=Softplus, func=Softplus,
@@ -32,23 +32,24 @@ if __name__ == "__main__":
model, model,
lr=0.006, lr=0.006,
error_norm='mse', error_norm='mse',
regularizer=1e-8, regularizer=1e-8)
lr_accelerate=None)
if args.s: if args.s:
#pinn.span_pts(200, 'grid', ['gamma_out']) pinn.span_pts(200, mode_spatial='grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
pinn.span_pts(200, 'grid', ['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out']) pinn.span_pts(2000, mode_spatial='random', locations=['D'])
pinn.span_pts(2000, 'random', ['D']) pinn.plot_pts()
#plotter = Plotter()
#plotter.plot_samples(pinn)
pinn.train(10000, 100) pinn.train(10000, 100)
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
for i, losses in enumerate(pinn.history):
file_.write('{} {}\n'.format(i, sum(losses)))
pinn.save_state('pina.stokes') pinn.save_state('pina.stokes')
else: else:
pinn.load_state('pina.stokes') pinn.load_state('pina.stokes')
plotter = Plotter() plotter = Plotter()
plotter.plot_samples(pinn) plotter.plot(pinn, component='ux')
plotter.plot(pinn) plotter.plot(pinn, component='uy')
plotter.plot(pinn, component='p')

View File

@@ -6,7 +6,12 @@ from .location import Location
class Condition: class Condition:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if len(args) == 2 and not kwargs: if 'data_weight' in kwargs:
self.data_weight = kwargs['data_weight']
if not 'data_weight' in kwargs:
self.data_weight = 1.
if len(args) == 2:
if (isinstance(args[0], torch.Tensor) and if (isinstance(args[0], torch.Tensor) and
isinstance(args[1], torch.Tensor)): isinstance(args[1], torch.Tensor)):
@@ -21,7 +26,7 @@ class Condition:
else: else:
raise ValueError raise ValueError
elif not args and len(kwargs) == 2: elif not args and len(kwargs) >= 2:
if 'input_points' in kwargs and 'output_points' in kwargs: if 'input_points' in kwargs and 'output_points' in kwargs:
self.input_points = kwargs['input_points'] self.input_points = kwargs['input_points']
@@ -33,3 +38,4 @@ class Condition:
raise ValueError raise ValueError
else: else:
raise ValueError raise ValueError

View File

@@ -3,7 +3,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from pina.label_tensor import LabelTensor from pina.label_tensor import LabelTensor
import warnings
import copy
class DeepONet(torch.nn.Module): class DeepONet(torch.nn.Module):
""" """
@@ -18,7 +19,7 @@ class DeepONet(torch.nn.Module):
<https://doi.org/10.1038/s42256-021-00302-5>`_ <https://doi.org/10.1038/s42256-021-00302-5>`_
""" """
def __init__(self, branch_net, trunk_net, output_variables): def __init__(self, branch_net, trunk_net, output_variables, inner_size=10):
""" """
:param torch.nn.Module branch_net: the neural network to use as branch :param torch.nn.Module branch_net: the neural network to use as branch
model. It has to take as input a :class:`LabelTensor`. The number model. It has to take as input a :class:`LabelTensor`. The number
@@ -43,7 +44,7 @@ class DeepONet(torch.nn.Module):
(1): Tanh() (1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True) (2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh() (3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True) (4): Linear(in_features=20, out_features=20, bias=True)
) )
) )
(branch_net): FeedForward( (branch_net): FeedForward(
@@ -53,20 +54,27 @@ class DeepONet(torch.nn.Module):
(1): Tanh() (1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True) (2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh() (3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True) (4): Linear(in_features=20, out_features=20, bias=True)
) )
) )
) )
""" """
super().__init__() super().__init__()
self.output_variables = output_variables
self.output_dimension = len(output_variables)
self.trunk_net = trunk_net self.trunk_net = trunk_net
self.branch_net = branch_net self.branch_net = branch_net
self.output_variables = output_variables if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int):
self.output_dimension = len(output_variables) if self.branch_net.output_dimension == self.trunk_net.output_dimension:
if self.output_dimension > 1: self.inner_size = self.branch_net.output_dimension
raise NotImplementedError('Vectorial DeepONet to be implemented') else:
raise ValueError('Branch and trunk networks have not the same output dimension.')
else:
warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.")
self.inner_size = self.output_dimension*inner_size
@property @property
def input_variables(self): def input_variables(self):
@@ -82,10 +90,16 @@ class DeepONet(torch.nn.Module):
:rtype: LabelTensor :rtype: LabelTensor
""" """
branch_output = self.branch_net( branch_output = self.branch_net(
x.extract(self.branch_net.input_variables)) x.extract(self.branch_net.input_variables))
trunk_output = self.trunk_net( trunk_output = self.trunk_net(
x.extract(self.trunk_net.input_variables)) x.extract(self.trunk_net.input_variables))
local_size = int(self.inner_size/self.output_dimension)
output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1) for i, var in enumerate(self.output_variables):
start = i*local_size
return LabelTensor(output_, self.output_variables) stop = (i+1)*local_size
local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
if i==0:
output_ = local_output
else:
output_ = output_.append(local_output)
return output_

View File

@@ -30,7 +30,7 @@ def div(output_, input_):
else: # really to improve else: # really to improve
a = [] a = []
for o in output_.T: for o in output_.T:
a.append(grad(o, input_)) a.append(grad(o, input_).extract(['x', 'y']))
div = torch.zeros(output_.shape[0], 1) div = torch.zeros(output_.shape[0], 1)
for i in range(output_.shape[1]): for i in range(output_.shape[1]):
div += a[i][:, i].reshape(-1, 1) div += a[i][:, i].reshape(-1, 1)
@@ -42,4 +42,25 @@ def nabla(output_, input_):
""" """
TODO TODO
""" """
return div(grad(output_, input_), input_) return div(grad(output_, input_).extract(['x', 'y']), input_)
def advection_term(output_, input_):
"""
TODO
"""
dimension = len(output_.labels)
for i, label in enumerate(output_.labels):
# compute u dot gradient in each direction
gradient_loc = grad(output_.extract([label]), input_).extract(input_.labels[:dimension])
dim_0 = gradient_loc.shape[0]
dim_1 = gradient_loc.shape[1]
u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1),
gradient_loc.view(dim_0, dim_1, 1))
u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc,
(u_dot_grad_loc.shape[0], u_dot_grad_loc.shape[1])), [input_.labels[i]])
if i==0:
adv_term = u_dot_grad_loc
else:
adv_term = adv_term.append(u_dot_grad_loc)
return adv_term

View File

@@ -5,6 +5,7 @@ import numpy as np
from pina.label_tensor import LabelTensor from pina.label_tensor import LabelTensor
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
class PINN(object): class PINN(object):
def __init__(self, def __init__(self,
@@ -13,7 +14,6 @@ class PINN(object):
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
lr=0.001, lr=0.001,
regularizer=0.00001, regularizer=0.00001,
data_weight=1.,
dtype=torch.float32, dtype=torch.float32,
device='cpu', device='cpu',
error_norm='mse'): error_norm='mse'):
@@ -53,13 +53,10 @@ class PINN(object):
self.truth_values = {} self.truth_values = {}
self.input_pts = {} self.input_pts = {}
self.trained_epoch = 0 self.trained_epoch = 0
self.optimizer = optimizer( self.optimizer = optimizer(
self.model.parameters(), lr=lr, weight_decay=regularizer) self.model.parameters(), lr=lr, weight_decay=regularizer)
self.data_weight = data_weight
@property @property
def problem(self): def problem(self):
return self._problem return self._problem
@@ -96,6 +93,7 @@ class PINN(object):
'optimizer_state' : self.optimizer.state_dict(), 'optimizer_state' : self.optimizer.state_dict(),
'optimizer_class' : self.optimizer.__class__, 'optimizer_class' : self.optimizer.__class__,
'history' : self.history, 'history' : self.history,
'input_points_dict' : self.input_pts,
} }
# TODO save also architecture param? # TODO save also architecture param?
@@ -117,22 +115,27 @@ class PINN(object):
self.trained_epoch = checkpoint['epoch'] self.trained_epoch = checkpoint['epoch']
self.history = checkpoint['history'] self.history = checkpoint['history']
self.input_pts = checkpoint['input_points_dict']
return self return self
def span_pts(self, n, mode='grid', locations='all'): def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
if locations == 'all': if locations == 'all':
locations = [condition for condition in self.problem.conditions] locations = [condition for condition in self.problem.conditions]
for location in locations: for location in locations:
condition = self.problem.conditions[location] condition = self.problem.conditions[location]
try: try:
pts = condition.location.sample(n, mode) pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
if n_params != 0:
pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
pts = pts.append(pts_params)
except: except:
pts = condition.input_points pts = condition.input_points
self.input_pts[location] = pts #.double() # TODO
self.input_pts[location] = pts#.double() # TODO
self.input_pts[location] = ( self.input_pts[location] = (
self.input_pts[location].to(dtype=self.dtype, self.input_pts[location].to(dtype=self.dtype,
device=self.device)) device=self.device))
@@ -140,19 +143,16 @@ class PINN(object):
self.input_pts[location].retain_grad() self.input_pts[location].retain_grad()
def plot_pts(self, locations='all'): def plot_pts(self, locations='all'):
import matplotlib import matplotlib
matplotlib.use('GTK3Agg') # matplotlib.use('GTK3Agg')
if locations == 'all': if locations == 'all':
locations = [condition for condition in self.problem.conditions] locations = [condition for condition in self.problem.conditions]
for location in locations: for location in locations:
x, y = self.input_pts[location].tensor.T x = self.input_pts[location].extract(['x'])
#plt.plot(x.detach(), y.detach(), 'o', label=location) y = self.input_pts[location].extract(['y'])
np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ') plt.plot(x.detach(), y.detach(), '.', label=location)
# np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
plt.legend() plt.legend()
plt.show() plt.show()
@@ -169,18 +169,23 @@ class PINN(object):
for condition_name in self.problem.conditions: for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name] condition = self.problem.conditions[condition_name]
pts = self.input_pts[condition_name] pts = self.input_pts[condition_name]
predicted = self.model(pts) predicted = self.model(pts)
if hasattr(condition, 'function'):
if isinstance(condition.function, list): if isinstance(condition.function, list):
for function in condition.function: for function in condition.function:
residuals = function(pts, predicted) residuals = function(pts, predicted)
losses.append(self._compute_norm(residuals)) local_loss = condition.data_weight*self._compute_norm(residuals)
else: losses.append(local_loss)
residuals = condition.function(pts, predicted) else:
losses.append(self._compute_norm(residuals)) residuals = condition.function(pts, predicted)
local_loss = condition.data_weight*self._compute_norm(residuals)
losses.append(local_loss)
elif hasattr(condition, 'output_points'):
residuals = predicted - condition.output_points
local_loss = condition.data_weight*self._compute_norm(residuals)
losses.append(local_loss)
self.optimizer.zero_grad() self.optimizer.zero_grad()
sum(losses).backward() sum(losses).backward()
self.optimizer.step() self.optimizer.step()

View File

@@ -1,6 +1,6 @@
""" Module for plotting. """ """ Module for plotting. """
import matplotlib import matplotlib
matplotlib.use('Qt5Agg') #matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
@@ -32,15 +32,15 @@ class Plotter:
truth_output = obj.problem.truth_solution(*pts.tensor.T).float() truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0]) fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1]) fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2]) fig.colorbar(cb, ax=axes[2])
else: else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes) fig.colorbar(cb, ax=axes)
@@ -66,66 +66,50 @@ class Plotter:
truth_output = obj.problem.truth_solution(*pts.tensor.T).float() truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0]) fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1]) fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2]) fig.colorbar(cb, ax=axes[2])
else: else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes)
def plot(self, obj, method='contourf', filename=None):
"""
"""
res = 256
pts = obj.problem.domain.sample(res, 'grid')
print(pts)
grids_container = [
pts.tensor[:, 0].reshape(res, res),
pts.tensor[:, 1].reshape(res, res),
]
predicted_output = obj.model(pts)
predicted_output = predicted_output['p']
if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[0])
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes[1])
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
fig.colorbar(cb, ax=axes[2])
else:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach()) cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
fig.colorbar(cb, ax=axes) fig.colorbar(cb, ax=axes)
if filename:
plt.savefig(filename)
else:
plt.show()
def plot(self, obj, method='contourf', filename=None): def plot(self, obj, method='contourf', component='u', parametric=False, params_value=1, filename=None):
""" """
""" """
res = 256 res = 256
pts = obj.problem.domain.sample(res, 'grid') pts = obj.problem.domain.sample(res, 'grid')
if parametric:
pts_params = torch.ones(pts.shape[0], len(obj.problem.parameters), dtype=pts.dtype)*params_value
pts_params = LabelTensor(pts_params, obj.problem.parameters)
pts = pts.append(pts_params)
grids_container = [ grids_container = [
pts[:, 0].reshape(res, res), pts[:, 0].reshape(res, res),
pts[:, 1].reshape(res, res), pts[:, 1].reshape(res, res),
] ]
ind_dict = {}
all_locations = [condition for condition in obj.problem.conditions]
for location in all_locations:
if hasattr(obj.problem.conditions[location], 'location'):
keys_range_ = obj.problem.conditions[location].location.range_.keys()
if ('x' in keys_range_) and ('y' in keys_range_):
range_x = obj.problem.conditions[location].location.range_['x']
range_y = obj.problem.conditions[location].location.range_['y']
ind_x = np.where(np.logical_or(pts[:, 0]<range_x[0], pts[:, 0]>range_x[1]))
ind_y = np.where(np.logical_or(pts[:, 1]<range_y[0], pts[:, 1]>range_y[1]))
ind_to_exclude = np.union1d(ind_x, ind_y)
ind_dict[location] = ind_to_exclude
import functools
from functools import reduce
final_inds = reduce(np.intersect1d, ind_dict.values())
predicted_output = obj.model(pts) predicted_output = obj.model(pts)
predicted_output = predicted_output.extract(['u']) predicted_output = predicted_output.extract([component])
predicted_output[final_inds] = np.nan
if hasattr(obj.problem, 'truth_solution'): if hasattr(obj.problem, 'truth_solution'):
truth_output = obj.problem.truth_solution(*pts.T).float() truth_output = obj.problem.truth_solution(*pts.T).float()
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
@@ -142,16 +126,16 @@ class Plotter:
fig.colorbar(cb, ax=axes) fig.colorbar(cb, ax=axes)
if filename: if filename:
plt.title('Output {} with parameter {}'.format(component, params_value))
plt.savefig(filename) plt.savefig(filename)
else: else:
plt.show() plt.show()
def plot_samples(self, obj): def plot_samples(self, obj):
for location in obj.input_pts: for location in obj.input_pts:
plt.plot(*obj.input_pts[location].T.detach(), '.', label=location) pts_x = obj.input_pts[location].extract(['x'])
pts_y = obj.input_pts[location].extract(['y'])
plt.plot(pts_x.detach(), pts_y.detach(), '.', label=location)
plt.legend() plt.legend()
plt.show() plt.show()

View File

@@ -20,9 +20,27 @@ class Span(Location):
else: else:
raise TypeError raise TypeError
def sample(self, n, mode='random'): def sample(self, n, mode='random', variables='all'):
bounds = np.array(list(self.range_.values())) if variables=='all':
spatial_range_ = list(self.range_.keys())
spatial_fixed_ = list(self.fixed_.keys())
bounds = np.array(list(self.range_.values()))
fixed = np.array(list(self.fixed_.values()))
else:
bounds = []
spatial_range_ = []
spatial_fixed_ = []
fixed = []
for variable in variables:
if variable in self.range_.keys():
spatial_range_.append(variable)
bounds.append(list(self.range_[variable]))
elif variable in self.fixed_.keys():
spatial_fixed_.append(variable)
fixed.append(int(self.fixed_[variable]))
fixed = torch.Tensor(fixed)
bounds = np.array(bounds)
if mode == 'random': if mode == 'random':
pts = np.random.uniform(size=(n, bounds.shape[0])) pts = np.random.uniform(size=(n, bounds.shape[0]))
elif mode == 'chebyshev': elif mode == 'chebyshev':
@@ -41,23 +59,24 @@ class Span(Location):
from scipy.stats import qmc from scipy.stats import qmc
sampler = qmc.LatinHypercube(d=bounds.shape[0]) sampler = qmc.LatinHypercube(d=bounds.shape[0])
pts = sampler.random(n) pts = sampler.random(n)
# Scale pts # Scale pts
pts *= bounds[:, 1] - bounds[:, 0] pts *= bounds[:, 1] - bounds[:, 0]
pts += bounds[:, 0] pts += bounds[:, 0]
pts = pts.astype(np.float32) pts = pts.astype(np.float32)
pts = torch.from_numpy(pts) pts = torch.from_numpy(pts)
fixed = torch.Tensor(list(self.fixed_.values())) pts_range_ = LabelTensor(pts, spatial_range_)
pts_fixed_ = torch.ones(pts.shape[0], len(self.fixed_),
dtype=pts.dtype) * fixed if not len(spatial_fixed_)==0:
pts_range_ = LabelTensor(pts, list(self.range_.keys())) pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
pts_fixed_ = LabelTensor(pts_fixed_, list(self.fixed_.keys())) dtype=pts.dtype) * fixed
pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
pts_range_ = pts_range_.append(pts_fixed_)
return pts_range_
if self.fixed_:
return pts_range_.append(pts_fixed_)
else:
return pts_range_
def meshgrid(self, n): def meshgrid(self, n):
pts = np.array([ pts = np.array([