preliminary modifications for N-S
This commit is contained in:
@@ -14,22 +14,22 @@ class Burgers1D(TimeDependentProblem, SpatialProblem):
|
|||||||
domain = Span({'x': [-1, 1], 't': [0, 1]})
|
domain = Span({'x': [-1, 1], 't': [0, 1]})
|
||||||
|
|
||||||
def burger_equation(input_, output_):
|
def burger_equation(input_, output_):
|
||||||
grad_u = grad(output_['u'], input_)
|
grad_u = grad(output_.extract(['u']), input_)
|
||||||
grad_x = grad_u['x']
|
grad_x = grad_u.extract(['x'])
|
||||||
grad_t = grad_u['t']
|
grad_t = grad_u.extract(['t'])
|
||||||
gradgrad_u_x = grad(grad_u['x'], input_)
|
gradgrad_u_x = grad(grad_u.extract(['x']), input_)
|
||||||
return (
|
return (
|
||||||
grad_u['t'] + output_['u']*grad_u['x'] -
|
grad_u.extract(['t']) + output_.extract(['u'])*grad_u.extract(['x']) -
|
||||||
(0.01/torch.pi)*gradgrad_u_x['x']
|
(0.01/torch.pi)*gradgrad_u_x.extract(['x'])
|
||||||
)
|
)
|
||||||
|
|
||||||
def nil_dirichlet(input_, output_):
|
def nil_dirichlet(input_, output_):
|
||||||
u_expected = 0.0
|
u_expected = 0.0
|
||||||
return output_['u'] - u_expected
|
return output_.extract(['u']) - u_expected
|
||||||
|
|
||||||
def initial_condition(input_, output_):
|
def initial_condition(input_, output_):
|
||||||
u_expected = -torch.sin(torch.pi*input_['x'])
|
u_expected = -torch.sin(torch.pi*input_.extract(['x']))
|
||||||
return output_['u'] - u_expected
|
return output_.extract(['u']) - u_expected
|
||||||
|
|
||||||
conditions = {
|
conditions = {
|
||||||
'gamma1': Condition(Span({'x': -1, 't': [0, 1]}), nil_dirichlet),
|
'gamma1': Condition(Span({'x': -1, 't': [0, 1]}), nil_dirichlet),
|
||||||
|
|||||||
@@ -12,26 +12,26 @@ class EllipticOptimalControl(Problem2D):
|
|||||||
def __init__(self, alpha=1):
|
def __init__(self, alpha=1):
|
||||||
|
|
||||||
def term1(input_, output_):
|
def term1(input_, output_):
|
||||||
grad_p = self.grad(output_['p'], input_)
|
grad_p = self.grad(output_.extract(['p']), input_)
|
||||||
gradgrad_p_x1 = self.grad(grad_p['x1'], input_)
|
gradgrad_p_x1 = self.grad(grad_p.extract(['x1']), input_)
|
||||||
gradgrad_p_x2 = self.grad(grad_p['x2'], input_)
|
gradgrad_p_x2 = self.grad(grad_p.extract(['x2']), input_)
|
||||||
yd = 2.0
|
yd = 2.0
|
||||||
return output_['y'] - yd - (gradgrad_p_x1['x1'] + gradgrad_p_x2['x2'])
|
return output_.extract(['y']) - yd - (gradgrad_p_x1.extract(['x1']) + gradgrad_p_x2.extract(['x2']))
|
||||||
|
|
||||||
def term2(input_, output_):
|
def term2(input_, output_):
|
||||||
grad_y = self.grad(output_['y'], input_)
|
grad_y = self.grad(output_.extract(['y']), input_)
|
||||||
gradgrad_y_x1 = self.grad(grad_y['x1'], input_)
|
gradgrad_y_x1 = self.grad(grad_y.extract(['x1']), input_)
|
||||||
gradgrad_y_x2 = self.grad(grad_y['x2'], input_)
|
gradgrad_y_x2 = self.grad(grad_y.extract(['x2']), input_)
|
||||||
return - (gradgrad_y_x1['x1'] + gradgrad_y_x2['x2']) - output_['u']
|
return - (gradgrad_y_x1.extract(['x1']) + gradgrad_y_x2.extract(['x2'])) - output_.extract(['u'])
|
||||||
|
|
||||||
def term3(input_, output_):
|
def term3(input_, output_):
|
||||||
return output_['p'] - output_['u']*alpha
|
return output_.extract(['p']) - output_.extract(['u'])*alpha
|
||||||
|
|
||||||
|
|
||||||
def nil_dirichlet(input_, output_):
|
def nil_dirichlet(input_, output_):
|
||||||
y_value = 0.0
|
y_value = 0.0
|
||||||
p_value = 0.0
|
p_value = 0.0
|
||||||
return torch.abs(output_['y'] - y_value) + torch.abs(output_['p'] - p_value)
|
return torch.abs(output_.extract(['y']) - y_value) + torch.abs(output_.extract(['p']) - p_value)
|
||||||
|
|
||||||
self.conditions = {
|
self.conditions = {
|
||||||
'gamma1': {'location': Segment((xmin, ymin), (xmax, ymin)), 'func': nil_dirichlet},
|
'gamma1': {'location': Segment((xmin, ymin), (xmax, ymin)), 'func': nil_dirichlet},
|
||||||
|
|||||||
@@ -14,13 +14,13 @@ class ParametricPoisson(SpatialProblem, ParametricProblem):
|
|||||||
|
|
||||||
def laplace_equation(input_, output_):
|
def laplace_equation(input_, output_):
|
||||||
force_term = torch.exp(
|
force_term = torch.exp(
|
||||||
- 2*(input_['x'] - input_['mu1'])**2 - 2*(input_['y'] -
|
- 2*(input_.extract(['x']) - input_.extract(['mu1']))**2 - 2*(input_.extract(['y']) -
|
||||||
input_['mu2'])**2)
|
input_.extract(['mu2']))**2)
|
||||||
return nabla(output_['u'], input_) - force_term
|
return nabla(output_.extract(['u']), input_) - force_term
|
||||||
|
|
||||||
def nil_dirichlet(input_, output_):
|
def nil_dirichlet(input_, output_):
|
||||||
value = 0.0
|
value = 0.0
|
||||||
return output_['u'] - value
|
return output_.extract(['u']) - value
|
||||||
|
|
||||||
conditions = {
|
conditions = {
|
||||||
'gamma1': Condition(
|
'gamma1': Condition(
|
||||||
|
|||||||
@@ -13,28 +13,24 @@ class Stokes(SpatialProblem):
|
|||||||
domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
||||||
|
|
||||||
def momentum(input_, output_):
|
def momentum(input_, output_):
|
||||||
#print(nabla(output_['ux', 'uy'], input_))
|
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
|
||||||
#print(grad(output_['p'], input_))
|
LabelTensor(nabla(output_.extract(['uy']), input_), ['y'])))
|
||||||
nabla_ = LabelTensor.hstack([
|
return - nabla_ + grad(output_.extract(['p']), input_)
|
||||||
LabelTensor(nabla(output_['ux'], input_), ['x']),
|
|
||||||
LabelTensor(nabla(output_['uy'], input_), ['y'])])
|
|
||||||
#return LabelTensor(nabla_.tensor + grad(output_['p'], input_).tensor, ['x', 'y'])
|
|
||||||
return nabla_.tensor + grad(output_['p'], input_).tensor
|
|
||||||
|
|
||||||
def continuity(input_, output_):
|
def continuity(input_, output_):
|
||||||
return div(output_['ux', 'uy'], input_)
|
return div(output_.extract(['ux', 'uy']), input_)
|
||||||
|
|
||||||
def inlet(input_, output_):
|
def inlet(input_, output_):
|
||||||
value = 2.0
|
value = 2 * (1 - input_.extract(['y'])**2)
|
||||||
return output_['ux'] - value
|
return output_.extract(['ux']) - value
|
||||||
|
|
||||||
def outlet(input_, output_):
|
def outlet(input_, output_):
|
||||||
value = 0.0
|
value = 0.0
|
||||||
return output_['p'] - value
|
return output_.extract(['p']) - value
|
||||||
|
|
||||||
def wall(input_, output_):
|
def wall(input_, output_):
|
||||||
value = 0.0
|
value = 0.0
|
||||||
return output_['ux', 'uy'].tensor - value
|
return output_.extract(['ux', 'uy']) - value
|
||||||
|
|
||||||
conditions = {
|
conditions = {
|
||||||
'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall),
|
'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall),
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Softplus
|
from torch.nn import Softplus
|
||||||
|
from pina import Plotter
|
||||||
from pina import PINN as pPINN
|
from pina import PINN as pPINN
|
||||||
from problems.parametric_poisson import ParametricPoisson
|
from problems.parametric_poisson import ParametricPoisson
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
@@ -14,7 +14,7 @@ class myFeature(torch.nn.Module):
|
|||||||
super(myFeature, self).__init__()
|
super(myFeature, self).__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.exp(- 2*(x['x'] - x['mu1'])**2 - 2*(x['y'] - x['mu2'])**2)
|
return torch.exp(- 2*(x.extract(['x']) - x.extract(['mu1']))**2 - 2*(x.extract(['y']) - x.extract(['mu2']))**2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -31,7 +31,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
poisson_problem = ParametricPoisson()
|
poisson_problem = ParametricPoisson()
|
||||||
model = FeedForward(
|
model = FeedForward(
|
||||||
layers=[200, 40, 10],
|
layers=[10, 10, 10],
|
||||||
output_variables=poisson_problem.output_variables,
|
output_variables=poisson_problem.output_variables,
|
||||||
input_variables=poisson_problem.input_variables,
|
input_variables=poisson_problem.input_variables,
|
||||||
func=Softplus,
|
func=Softplus,
|
||||||
@@ -42,15 +42,20 @@ if __name__ == "__main__":
|
|||||||
poisson_problem,
|
poisson_problem,
|
||||||
model,
|
model,
|
||||||
lr=0.0006,
|
lr=0.0006,
|
||||||
regularizer=1e-6,
|
regularizer=1e-6)
|
||||||
lr_accelerate=None)
|
|
||||||
|
|
||||||
if args.s:
|
if args.s:
|
||||||
|
|
||||||
pinn.span_pts(2000, 'random', ['D'])
|
pinn.span_pts(500, n_params=10, mode_spatial='random', locations=['D'])
|
||||||
pinn.span_pts(200, 'random', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
pinn.span_pts(200, n_params=10, mode_spatial='random', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||||
pinn.train(10000, 10)
|
pinn.plot_pts()
|
||||||
|
pinn.train(10000, 100)
|
||||||
|
with open('param_poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
|
||||||
|
for i, losses in enumerate(pinn.history):
|
||||||
|
file_.write('{} {}\n'.format(i, sum(losses)))
|
||||||
pinn.save_state('pina.poisson_param')
|
pinn.save_state('pina.poisson_param')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pinn.load_state('pina.poisson_param')
|
pinn.load_state('pina.poisson_param')
|
||||||
|
plotter = Plotter()
|
||||||
|
plotter.plot(pinn, component='u', parametric=True, params_value=0)
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ if __name__ == "__main__":
|
|||||||
if args.s:
|
if args.s:
|
||||||
|
|
||||||
print(pinn)
|
print(pinn)
|
||||||
pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
pinn.span_pts(20, mode_spatial='grid', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||||
pinn.span_pts(20, 'grid', ['D'])
|
pinn.span_pts(20, mode_spatial='grid', locations=['D'])
|
||||||
#pinn.plot_pts()
|
pinn.plot_pts()
|
||||||
pinn.train(5000, 100)
|
pinn.train(5000, 100)
|
||||||
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
|
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
|
||||||
for i, losses in enumerate(pinn.history):
|
for i, losses in enumerate(pinn.history):
|
||||||
@@ -64,6 +64,6 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
pinn.load_state('pina.poisson')
|
pinn.load_state('pina.poisson')
|
||||||
plotter = Plotter()
|
plotter = Plotter()
|
||||||
plotter.plot(pinn)
|
plotter.plot(pinn, component='u')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
stokes_problem = Stokes()
|
stokes_problem = Stokes()
|
||||||
model = FeedForward(
|
model = FeedForward(
|
||||||
layers=[40, 20, 20, 10],
|
layers=[10, 10, 10, 10],
|
||||||
output_variables=stokes_problem.output_variables,
|
output_variables=stokes_problem.output_variables,
|
||||||
input_variables=stokes_problem.input_variables,
|
input_variables=stokes_problem.input_variables,
|
||||||
func=Softplus,
|
func=Softplus,
|
||||||
@@ -32,23 +32,24 @@ if __name__ == "__main__":
|
|||||||
model,
|
model,
|
||||||
lr=0.006,
|
lr=0.006,
|
||||||
error_norm='mse',
|
error_norm='mse',
|
||||||
regularizer=1e-8,
|
regularizer=1e-8)
|
||||||
lr_accelerate=None)
|
|
||||||
|
|
||||||
if args.s:
|
if args.s:
|
||||||
|
|
||||||
#pinn.span_pts(200, 'grid', ['gamma_out'])
|
pinn.span_pts(200, mode_spatial='grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||||
pinn.span_pts(200, 'grid', ['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
pinn.span_pts(2000, mode_spatial='random', locations=['D'])
|
||||||
pinn.span_pts(2000, 'random', ['D'])
|
pinn.plot_pts()
|
||||||
#plotter = Plotter()
|
|
||||||
#plotter.plot_samples(pinn)
|
|
||||||
pinn.train(10000, 100)
|
pinn.train(10000, 100)
|
||||||
|
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
|
||||||
|
for i, losses in enumerate(pinn.history):
|
||||||
|
file_.write('{} {}\n'.format(i, sum(losses)))
|
||||||
pinn.save_state('pina.stokes')
|
pinn.save_state('pina.stokes')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pinn.load_state('pina.stokes')
|
pinn.load_state('pina.stokes')
|
||||||
plotter = Plotter()
|
plotter = Plotter()
|
||||||
plotter.plot_samples(pinn)
|
plotter.plot(pinn, component='ux')
|
||||||
plotter.plot(pinn)
|
plotter.plot(pinn, component='uy')
|
||||||
|
plotter.plot(pinn, component='p')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,12 @@ from .location import Location
|
|||||||
class Condition:
|
class Condition:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
||||||
if len(args) == 2 and not kwargs:
|
if 'data_weight' in kwargs:
|
||||||
|
self.data_weight = kwargs['data_weight']
|
||||||
|
if not 'data_weight' in kwargs:
|
||||||
|
self.data_weight = 1.
|
||||||
|
|
||||||
|
if len(args) == 2:
|
||||||
|
|
||||||
if (isinstance(args[0], torch.Tensor) and
|
if (isinstance(args[0], torch.Tensor) and
|
||||||
isinstance(args[1], torch.Tensor)):
|
isinstance(args[1], torch.Tensor)):
|
||||||
@@ -21,7 +26,7 @@ class Condition:
|
|||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
elif not args and len(kwargs) == 2:
|
elif not args and len(kwargs) >= 2:
|
||||||
|
|
||||||
if 'input_points' in kwargs and 'output_points' in kwargs:
|
if 'input_points' in kwargs and 'output_points' in kwargs:
|
||||||
self.input_points = kwargs['input_points']
|
self.input_points = kwargs['input_points']
|
||||||
@@ -33,3 +38,4 @@ class Condition:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from pina.label_tensor import LabelTensor
|
from pina.label_tensor import LabelTensor
|
||||||
|
import warnings
|
||||||
|
import copy
|
||||||
|
|
||||||
class DeepONet(torch.nn.Module):
|
class DeepONet(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -18,7 +19,7 @@ class DeepONet(torch.nn.Module):
|
|||||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, branch_net, trunk_net, output_variables):
|
def __init__(self, branch_net, trunk_net, output_variables, inner_size=10):
|
||||||
"""
|
"""
|
||||||
:param torch.nn.Module branch_net: the neural network to use as branch
|
:param torch.nn.Module branch_net: the neural network to use as branch
|
||||||
model. It has to take as input a :class:`LabelTensor`. The number
|
model. It has to take as input a :class:`LabelTensor`. The number
|
||||||
@@ -43,7 +44,7 @@ class DeepONet(torch.nn.Module):
|
|||||||
(1): Tanh()
|
(1): Tanh()
|
||||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
(3): Tanh()
|
(3): Tanh()
|
||||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
(4): Linear(in_features=20, out_features=20, bias=True)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
(branch_net): FeedForward(
|
(branch_net): FeedForward(
|
||||||
@@ -53,20 +54,27 @@ class DeepONet(torch.nn.Module):
|
|||||||
(1): Tanh()
|
(1): Tanh()
|
||||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
(3): Tanh()
|
(3): Tanh()
|
||||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
(4): Linear(in_features=20, out_features=20, bias=True)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.output_variables = output_variables
|
||||||
|
self.output_dimension = len(output_variables)
|
||||||
|
|
||||||
self.trunk_net = trunk_net
|
self.trunk_net = trunk_net
|
||||||
self.branch_net = branch_net
|
self.branch_net = branch_net
|
||||||
|
|
||||||
self.output_variables = output_variables
|
if isinstance(self.branch_net.output_variables, int) and isinstance(self.branch_net.output_variables, int):
|
||||||
self.output_dimension = len(output_variables)
|
if self.branch_net.output_dimension == self.trunk_net.output_dimension:
|
||||||
if self.output_dimension > 1:
|
self.inner_size = self.branch_net.output_dimension
|
||||||
raise NotImplementedError('Vectorial DeepONet to be implemented')
|
else:
|
||||||
|
raise ValueError('Branch and trunk networks have not the same output dimension.')
|
||||||
|
else:
|
||||||
|
warnings.warn("The output dimension of the branch and trunk networks has been imposed by default as 10 for each output variable. To set it change the output_variable of networks to an integer.")
|
||||||
|
self.inner_size = self.output_dimension*inner_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self):
|
def input_variables(self):
|
||||||
@@ -85,7 +93,13 @@ class DeepONet(torch.nn.Module):
|
|||||||
x.extract(self.branch_net.input_variables))
|
x.extract(self.branch_net.input_variables))
|
||||||
trunk_output = self.trunk_net(
|
trunk_output = self.trunk_net(
|
||||||
x.extract(self.trunk_net.input_variables))
|
x.extract(self.trunk_net.input_variables))
|
||||||
|
local_size = int(self.inner_size/self.output_dimension)
|
||||||
output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1)
|
for i, var in enumerate(self.output_variables):
|
||||||
|
start = i*local_size
|
||||||
return LabelTensor(output_, self.output_variables)
|
stop = (i+1)*local_size
|
||||||
|
local_output = LabelTensor(torch.sum(branch_output[:, start:stop] * trunk_output[:, start:stop], dim=1).reshape(-1, 1), var)
|
||||||
|
if i==0:
|
||||||
|
output_ = local_output
|
||||||
|
else:
|
||||||
|
output_ = output_.append(local_output)
|
||||||
|
return output_
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def div(output_, input_):
|
|||||||
else: # really to improve
|
else: # really to improve
|
||||||
a = []
|
a = []
|
||||||
for o in output_.T:
|
for o in output_.T:
|
||||||
a.append(grad(o, input_))
|
a.append(grad(o, input_).extract(['x', 'y']))
|
||||||
div = torch.zeros(output_.shape[0], 1)
|
div = torch.zeros(output_.shape[0], 1)
|
||||||
for i in range(output_.shape[1]):
|
for i in range(output_.shape[1]):
|
||||||
div += a[i][:, i].reshape(-1, 1)
|
div += a[i][:, i].reshape(-1, 1)
|
||||||
@@ -42,4 +42,25 @@ def nabla(output_, input_):
|
|||||||
"""
|
"""
|
||||||
TODO
|
TODO
|
||||||
"""
|
"""
|
||||||
return div(grad(output_, input_), input_)
|
return div(grad(output_, input_).extract(['x', 'y']), input_)
|
||||||
|
|
||||||
|
|
||||||
|
def advection_term(output_, input_):
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
|
dimension = len(output_.labels)
|
||||||
|
for i, label in enumerate(output_.labels):
|
||||||
|
# compute u dot gradient in each direction
|
||||||
|
gradient_loc = grad(output_.extract([label]), input_).extract(input_.labels[:dimension])
|
||||||
|
dim_0 = gradient_loc.shape[0]
|
||||||
|
dim_1 = gradient_loc.shape[1]
|
||||||
|
u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1),
|
||||||
|
gradient_loc.view(dim_0, dim_1, 1))
|
||||||
|
u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc,
|
||||||
|
(u_dot_grad_loc.shape[0], u_dot_grad_loc.shape[1])), [input_.labels[i]])
|
||||||
|
if i==0:
|
||||||
|
adv_term = u_dot_grad_loc
|
||||||
|
else:
|
||||||
|
adv_term = adv_term.append(u_dot_grad_loc)
|
||||||
|
return adv_term
|
||||||
|
|||||||
47
pina/pinn.py
47
pina/pinn.py
@@ -5,6 +5,7 @@ import numpy as np
|
|||||||
from pina.label_tensor import LabelTensor
|
from pina.label_tensor import LabelTensor
|
||||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||||
|
|
||||||
|
|
||||||
class PINN(object):
|
class PINN(object):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -13,7 +14,6 @@ class PINN(object):
|
|||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
lr=0.001,
|
lr=0.001,
|
||||||
regularizer=0.00001,
|
regularizer=0.00001,
|
||||||
data_weight=1.,
|
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device='cpu',
|
device='cpu',
|
||||||
error_norm='mse'):
|
error_norm='mse'):
|
||||||
@@ -53,13 +53,10 @@ class PINN(object):
|
|||||||
self.truth_values = {}
|
self.truth_values = {}
|
||||||
self.input_pts = {}
|
self.input_pts = {}
|
||||||
|
|
||||||
|
|
||||||
self.trained_epoch = 0
|
self.trained_epoch = 0
|
||||||
self.optimizer = optimizer(
|
self.optimizer = optimizer(
|
||||||
self.model.parameters(), lr=lr, weight_decay=regularizer)
|
self.model.parameters(), lr=lr, weight_decay=regularizer)
|
||||||
|
|
||||||
self.data_weight = data_weight
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def problem(self):
|
def problem(self):
|
||||||
return self._problem
|
return self._problem
|
||||||
@@ -96,6 +93,7 @@ class PINN(object):
|
|||||||
'optimizer_state' : self.optimizer.state_dict(),
|
'optimizer_state' : self.optimizer.state_dict(),
|
||||||
'optimizer_class' : self.optimizer.__class__,
|
'optimizer_class' : self.optimizer.__class__,
|
||||||
'history' : self.history,
|
'history' : self.history,
|
||||||
|
'input_points_dict' : self.input_pts,
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO save also architecture param?
|
# TODO save also architecture param?
|
||||||
@@ -117,21 +115,26 @@ class PINN(object):
|
|||||||
self.trained_epoch = checkpoint['epoch']
|
self.trained_epoch = checkpoint['epoch']
|
||||||
self.history = checkpoint['history']
|
self.history = checkpoint['history']
|
||||||
|
|
||||||
|
self.input_pts = checkpoint['input_points_dict']
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def span_pts(self, n, mode='grid', locations='all'):
|
def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
|
||||||
if locations == 'all':
|
if locations == 'all':
|
||||||
locations = [condition for condition in self.problem.conditions]
|
locations = [condition for condition in self.problem.conditions]
|
||||||
|
|
||||||
for location in locations:
|
for location in locations:
|
||||||
condition = self.problem.conditions[location]
|
condition = self.problem.conditions[location]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pts = condition.location.sample(n, mode)
|
pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
|
||||||
|
if n_params != 0:
|
||||||
|
pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
|
||||||
|
pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
|
||||||
|
pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
|
||||||
|
pts = pts.append(pts_params)
|
||||||
except:
|
except:
|
||||||
pts = condition.input_points
|
pts = condition.input_points
|
||||||
|
|
||||||
self.input_pts[location] = pts #.double() # TODO
|
self.input_pts[location] = pts #.double() # TODO
|
||||||
self.input_pts[location] = (
|
self.input_pts[location] = (
|
||||||
self.input_pts[location].to(dtype=self.dtype,
|
self.input_pts[location].to(dtype=self.dtype,
|
||||||
@@ -140,19 +143,16 @@ class PINN(object):
|
|||||||
self.input_pts[location].retain_grad()
|
self.input_pts[location].retain_grad()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_pts(self, locations='all'):
|
def plot_pts(self, locations='all'):
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('GTK3Agg')
|
# matplotlib.use('GTK3Agg')
|
||||||
if locations == 'all':
|
if locations == 'all':
|
||||||
locations = [condition for condition in self.problem.conditions]
|
locations = [condition for condition in self.problem.conditions]
|
||||||
|
|
||||||
for location in locations:
|
for location in locations:
|
||||||
x, y = self.input_pts[location].tensor.T
|
x = self.input_pts[location].extract(['x'])
|
||||||
#plt.plot(x.detach(), y.detach(), 'o', label=location)
|
y = self.input_pts[location].extract(['y'])
|
||||||
np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
|
plt.plot(x.detach(), y.detach(), '.', label=location)
|
||||||
|
# np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
|
||||||
|
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
@@ -169,18 +169,23 @@ class PINN(object):
|
|||||||
for condition_name in self.problem.conditions:
|
for condition_name in self.problem.conditions:
|
||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
pts = self.input_pts[condition_name]
|
pts = self.input_pts[condition_name]
|
||||||
|
|
||||||
predicted = self.model(pts)
|
predicted = self.model(pts)
|
||||||
|
if hasattr(condition, 'function'):
|
||||||
if isinstance(condition.function, list):
|
if isinstance(condition.function, list):
|
||||||
for function in condition.function:
|
for function in condition.function:
|
||||||
residuals = function(pts, predicted)
|
residuals = function(pts, predicted)
|
||||||
losses.append(self._compute_norm(residuals))
|
local_loss = condition.data_weight*self._compute_norm(residuals)
|
||||||
|
losses.append(local_loss)
|
||||||
else:
|
else:
|
||||||
residuals = condition.function(pts, predicted)
|
residuals = condition.function(pts, predicted)
|
||||||
losses.append(self._compute_norm(residuals))
|
local_loss = condition.data_weight*self._compute_norm(residuals)
|
||||||
|
losses.append(local_loss)
|
||||||
|
elif hasattr(condition, 'output_points'):
|
||||||
|
residuals = predicted - condition.output_points
|
||||||
|
local_loss = condition.data_weight*self._compute_norm(residuals)
|
||||||
|
losses.append(local_loss)
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
sum(losses).backward()
|
sum(losses).backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
""" Module for plotting. """
|
""" Module for plotting. """
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Qt5Agg')
|
#matplotlib.use('Qt5Agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -32,15 +32,15 @@ class Plotter:
|
|||||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||||
|
|
||||||
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes[0])
|
fig.colorbar(cb, ax=axes[0])
|
||||||
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes[1])
|
fig.colorbar(cb, ax=axes[1])
|
||||||
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
|
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
|
||||||
fig.colorbar(cb, ax=axes[2])
|
fig.colorbar(cb, ax=axes[2])
|
||||||
else:
|
else:
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes)
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,66 +66,50 @@ class Plotter:
|
|||||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||||
|
|
||||||
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
cb = getattr(axes[0], method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes[0])
|
fig.colorbar(cb, ax=axes[0])
|
||||||
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes[1])
|
fig.colorbar(cb, ax=axes[1])
|
||||||
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
|
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.float().flatten()).detach().reshape(res, res))
|
||||||
fig.colorbar(cb, ax=axes[2])
|
fig.colorbar(cb, ax=axes[2])
|
||||||
else:
|
else:
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
|
||||||
fig.colorbar(cb, ax=axes)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot(self, obj, method='contourf', filename=None):
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
res = 256
|
|
||||||
pts = obj.problem.domain.sample(res, 'grid')
|
|
||||||
print(pts)
|
|
||||||
grids_container = [
|
|
||||||
pts.tensor[:, 0].reshape(res, res),
|
|
||||||
pts.tensor[:, 1].reshape(res, res),
|
|
||||||
]
|
|
||||||
predicted_output = obj.model(pts)
|
|
||||||
predicted_output = predicted_output['p']
|
|
||||||
|
|
||||||
if hasattr(obj.problem, 'truth_solution'):
|
|
||||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
|
||||||
|
|
||||||
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
|
||||||
fig.colorbar(cb, ax=axes[0])
|
|
||||||
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
|
||||||
fig.colorbar(cb, ax=axes[1])
|
|
||||||
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
|
|
||||||
fig.colorbar(cb, ax=axes[2])
|
|
||||||
else:
|
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
|
||||||
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
|
||||||
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes)
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
if filename:
|
|
||||||
plt.savefig(filename)
|
|
||||||
else:
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def plot(self, obj, method='contourf', filename=None):
|
def plot(self, obj, method='contourf', component='u', parametric=False, params_value=1, filename=None):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
res = 256
|
res = 256
|
||||||
pts = obj.problem.domain.sample(res, 'grid')
|
pts = obj.problem.domain.sample(res, 'grid')
|
||||||
|
if parametric:
|
||||||
|
pts_params = torch.ones(pts.shape[0], len(obj.problem.parameters), dtype=pts.dtype)*params_value
|
||||||
|
pts_params = LabelTensor(pts_params, obj.problem.parameters)
|
||||||
|
pts = pts.append(pts_params)
|
||||||
grids_container = [
|
grids_container = [
|
||||||
pts[:, 0].reshape(res, res),
|
pts[:, 0].reshape(res, res),
|
||||||
pts[:, 1].reshape(res, res),
|
pts[:, 1].reshape(res, res),
|
||||||
]
|
]
|
||||||
|
ind_dict = {}
|
||||||
|
all_locations = [condition for condition in obj.problem.conditions]
|
||||||
|
for location in all_locations:
|
||||||
|
if hasattr(obj.problem.conditions[location], 'location'):
|
||||||
|
keys_range_ = obj.problem.conditions[location].location.range_.keys()
|
||||||
|
if ('x' in keys_range_) and ('y' in keys_range_):
|
||||||
|
range_x = obj.problem.conditions[location].location.range_['x']
|
||||||
|
range_y = obj.problem.conditions[location].location.range_['y']
|
||||||
|
ind_x = np.where(np.logical_or(pts[:, 0]<range_x[0], pts[:, 0]>range_x[1]))
|
||||||
|
ind_y = np.where(np.logical_or(pts[:, 1]<range_y[0], pts[:, 1]>range_y[1]))
|
||||||
|
ind_to_exclude = np.union1d(ind_x, ind_y)
|
||||||
|
ind_dict[location] = ind_to_exclude
|
||||||
|
import functools
|
||||||
|
from functools import reduce
|
||||||
|
final_inds = reduce(np.intersect1d, ind_dict.values())
|
||||||
predicted_output = obj.model(pts)
|
predicted_output = obj.model(pts)
|
||||||
predicted_output = predicted_output.extract(['u'])
|
predicted_output = predicted_output.extract([component])
|
||||||
|
predicted_output[final_inds] = np.nan
|
||||||
if hasattr(obj.problem, 'truth_solution'):
|
if hasattr(obj.problem, 'truth_solution'):
|
||||||
truth_output = obj.problem.truth_solution(*pts.T).float()
|
truth_output = obj.problem.truth_solution(*pts.T).float()
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||||
@@ -142,16 +126,16 @@ class Plotter:
|
|||||||
fig.colorbar(cb, ax=axes)
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
|
plt.title('Output {} with parameter {}'.format(component, params_value))
|
||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_samples(self, obj):
|
def plot_samples(self, obj):
|
||||||
|
|
||||||
for location in obj.input_pts:
|
for location in obj.input_pts:
|
||||||
plt.plot(*obj.input_pts[location].T.detach(), '.', label=location)
|
pts_x = obj.input_pts[location].extract(['x'])
|
||||||
|
pts_y = obj.input_pts[location].extract(['y'])
|
||||||
|
plt.plot(pts_x.detach(), pts_y.detach(), '.', label=location)
|
||||||
|
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
39
pina/span.py
39
pina/span.py
@@ -20,9 +20,27 @@ class Span(Location):
|
|||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def sample(self, n, mode='random'):
|
def sample(self, n, mode='random', variables='all'):
|
||||||
|
|
||||||
|
if variables=='all':
|
||||||
|
spatial_range_ = list(self.range_.keys())
|
||||||
|
spatial_fixed_ = list(self.fixed_.keys())
|
||||||
bounds = np.array(list(self.range_.values()))
|
bounds = np.array(list(self.range_.values()))
|
||||||
|
fixed = np.array(list(self.fixed_.values()))
|
||||||
|
else:
|
||||||
|
bounds = []
|
||||||
|
spatial_range_ = []
|
||||||
|
spatial_fixed_ = []
|
||||||
|
fixed = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable in self.range_.keys():
|
||||||
|
spatial_range_.append(variable)
|
||||||
|
bounds.append(list(self.range_[variable]))
|
||||||
|
elif variable in self.fixed_.keys():
|
||||||
|
spatial_fixed_.append(variable)
|
||||||
|
fixed.append(int(self.fixed_[variable]))
|
||||||
|
fixed = torch.Tensor(fixed)
|
||||||
|
bounds = np.array(bounds)
|
||||||
if mode == 'random':
|
if mode == 'random':
|
||||||
pts = np.random.uniform(size=(n, bounds.shape[0]))
|
pts = np.random.uniform(size=(n, bounds.shape[0]))
|
||||||
elif mode == 'chebyshev':
|
elif mode == 'chebyshev':
|
||||||
@@ -41,24 +59,25 @@ class Span(Location):
|
|||||||
from scipy.stats import qmc
|
from scipy.stats import qmc
|
||||||
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
||||||
pts = sampler.random(n)
|
pts = sampler.random(n)
|
||||||
|
|
||||||
# Scale pts
|
# Scale pts
|
||||||
pts *= bounds[:, 1] - bounds[:, 0]
|
pts *= bounds[:, 1] - bounds[:, 0]
|
||||||
pts += bounds[:, 0]
|
pts += bounds[:, 0]
|
||||||
|
|
||||||
pts = pts.astype(np.float32)
|
pts = pts.astype(np.float32)
|
||||||
pts = torch.from_numpy(pts)
|
pts = torch.from_numpy(pts)
|
||||||
|
|
||||||
fixed = torch.Tensor(list(self.fixed_.values()))
|
pts_range_ = LabelTensor(pts, spatial_range_)
|
||||||
pts_fixed_ = torch.ones(pts.shape[0], len(self.fixed_),
|
|
||||||
dtype=pts.dtype) * fixed
|
if not len(spatial_fixed_)==0:
|
||||||
pts_range_ = LabelTensor(pts, list(self.range_.keys()))
|
pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
|
||||||
pts_fixed_ = LabelTensor(pts_fixed_, list(self.fixed_.keys()))
|
dtype=pts.dtype) * fixed
|
||||||
|
|
||||||
|
pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
|
||||||
|
pts_range_ = pts_range_.append(pts_fixed_)
|
||||||
|
|
||||||
if self.fixed_:
|
|
||||||
return pts_range_.append(pts_fixed_)
|
|
||||||
else:
|
|
||||||
return pts_range_
|
return pts_range_
|
||||||
|
|
||||||
|
|
||||||
def meshgrid(self, n):
|
def meshgrid(self, n):
|
||||||
pts = np.array([
|
pts = np.array([
|
||||||
np.linspace(0, 1, n)
|
np.linspace(0, 1, n)
|
||||||
|
|||||||
Reference in New Issue
Block a user