fix old codes

This commit is contained in:
Your Name
2022-07-11 10:58:15 +02:00
parent 088649e042
commit f526a26050
19 changed files with 385 additions and 457 deletions

View File

@@ -14,13 +14,12 @@ class Burgers1D(TimeDependentProblem, SpatialProblem):
domain = Span({'x': [-1, 1], 't': [0, 1]})
def burger_equation(input_, output_):
grad_u = grad(output_.extract(['u']), input_)
grad_x = grad_u.extract(['x'])
grad_t = grad_u.extract(['t'])
gradgrad_u_x = grad(grad_u.extract(['x']), input_)
du = grad(output_, input_)
ddu = grad(du, input_, components=['dudx'])
return (
grad_u.extract(['t']) + output_.extract(['u'])*grad_u.extract(['x']) -
(0.01/torch.pi)*gradgrad_u_x.extract(['x'])
du.extract(['dudt']) +
output_.extract(['u'])*du.extract(['dudx']) -
(0.01/torch.pi)*ddu.extract(['ddudxdx'])
)
def nil_dirichlet(input_, output_):

View File

@@ -1,52 +1,59 @@
import numpy as np
import torch
from pina.problem import Problem
from pina.segment import Segment
from pina.cube import Cube
from pina.problem2d import Problem2D
xmin, xmax, ymin, ymax = -1, 1, -1, 1
class ParametricEllipticOptimalControl(Problem2D):
def __init__(self, alpha=1):
def term1(input_, param_, output_):
grad_p = self.grad(output_['p'], input_)
gradgrad_p_x1 = self.grad(grad_p['x1'], input_)
gradgrad_p_x2 = self.grad(grad_p['x2'], input_)
#print('mu', input_['mu'])
return output_['y'] - input_['mu'] - (gradgrad_p_x1['x1'] + gradgrad_p_x2['x2'])
def term2(input_, param_, output_):
grad_y = self.grad(output_['y'], input_)
gradgrad_y_x1 = self.grad(grad_y['x1'], input_)
gradgrad_y_x2 = self.grad(grad_y['x2'], input_)
return - (gradgrad_y_x1['x1'] + gradgrad_y_x2['x2']) - output_['u_param']
def term3(input_, param_, output_):
#print('a', input_['alpha'], output_['p'], output_['u_param'])
return output_['p'] - output_['u_param']*input_['alpha']
from pina import Span, Condition
from pina.problem import SpatialProblem, ParametricProblem
from pina.operators import grad, nabla
def nil_dirichlet(input_, param_, output_):
y_value = 0.0
p_value = 0.0
return torch.abs(output_['y'] - y_value) + torch.abs(output_['p'] - p_value)
class ParametricEllipticOptimalControl(SpatialProblem, ParametricProblem):
self.conditions = {
'gamma1': {'location': Segment((xmin, ymin), (xmax, ymin)), 'func': nil_dirichlet},
'gamma2': {'location': Segment((xmax, ymin), (xmax, ymax)), 'func': nil_dirichlet},
'gamma3': {'location': Segment((xmax, ymax), (xmin, ymax)), 'func': nil_dirichlet},
'gamma4': {'location': Segment((xmin, ymax), (xmin, ymin)), 'func': nil_dirichlet},
'D1': {'location': Cube([[xmin, xmax], [ymin, ymax]]), 'func': [term1, term2]},
#'D2': {'location': Cube([[0, 1], [0, 1]]), 'func': term2},
#'D3': {'location': Cube([[0, 1], [0, 1]]), 'func': term3}
}
xmin, xmax, ymin, ymax = -1, 1, -1, 1
amin, amax = 0.0001, 1
mumin, mumax = 0.5, 3
mu_range = [mumin, mumax]
a_range = [amin, amax]
x_range = [xmin, xmax]
y_range = [ymin, ymax]
self.input_variables = ['x1', 'x2']
self.output_variables = ['u', 'p', 'y']
self.parameters = ['mu', 'alpha']
self.spatial_domain = Cube([[xmin, xmax], [xmin, xmax]])
self.parameter_domain = np.array([[0.5, 3], [0.0001, 1]])
spatial_variables = ['x1', 'x2']
parameters = ['mu', 'alpha']
output_variables = ['u', 'p', 'y']
domain = Span({
'x1': x_range, 'x2': y_range, 'mu': mu_range, 'alpha': a_range})
def term1(input_, output_):
laplace_p = nabla(output_, input_, components=['p'], d=['x1', 'x2'])
return output_.extract(['y']) - input_.extract(['mu']) - laplace_p
def term2(input_, output_):
laplace_y = nabla(output_, input_, components=['y'], d=['x1', 'x2'])
return - laplace_y - output_.extract(['u_param'])
def state_dirichlet(input_, output_):
y_exp = 0.0
return output_.extract(['y']) - y_exp
def adj_dirichlet(input_, output_):
p_exp = 0.0
return output_.extract(['p']) - p_exp
conditions = {
'gamma1': Condition(
Span({'x1': x_range, 'x2': 1, 'mu': mu_range, 'alpha': a_range}),
[state_dirichlet, adj_dirichlet]),
'gamma2': Condition(
Span({'x1': x_range, 'x2': -1, 'mu': mu_range, 'alpha': a_range}),
[state_dirichlet, adj_dirichlet]),
'gamma3': Condition(
Span({'x1': 1, 'x2': y_range, 'mu': mu_range, 'alpha': a_range}),
[state_dirichlet, adj_dirichlet]),
'gamma4': Condition(
Span({'x1': -1, 'x2': y_range, 'mu': mu_range, 'alpha': a_range}),
[state_dirichlet, adj_dirichlet]),
'D': Condition(
Span({'x1': x_range, 'x2': y_range,
'mu': mu_range, 'alpha': a_range}),
[term1, term2]),
}

View File

@@ -14,8 +14,8 @@ class ParametricPoisson(SpatialProblem, ParametricProblem):
def laplace_equation(input_, output_):
force_term = torch.exp(
- 2*(input_.extract(['x']) - input_.extract(['mu1']))**2 - 2*(input_.extract(['y']) -
input_.extract(['mu2']))**2)
- 2*(input_.extract(['x']) - input_.extract(['mu1']))**2
- 2*(input_.extract(['y']) - input_.extract(['mu2']))**2)
return nabla(output_.extract(['u']), input_) - force_term
def nil_dirichlet(input_, output_):

View File

@@ -23,11 +23,11 @@ class Poisson(SpatialProblem):
return output_.extract(['u']) - value
conditions = {
'gamma1': Condition(Span({'x': [-1, 1], 'y': 1}), nil_dirichlet),
'gamma2': Condition(Span({'x': [-1, 1], 'y': -1}), nil_dirichlet),
'gamma3': Condition(Span({'x': 1, 'y': [-1, 1]}), nil_dirichlet),
'gamma4': Condition(Span({'x': -1, 'y': [-1, 1]}), nil_dirichlet),
'D': Condition(Span({'x': [-1, 1], 'y': [-1, 1]}), laplace_equation),
'gamma1': Condition(Span({'x': [0, 1], 'y': 1}), nil_dirichlet),
'gamma2': Condition(Span({'x': [0, 1], 'y': 0}), nil_dirichlet),
'gamma3': Condition(Span({'x': 1, 'y': [0, 1]}), nil_dirichlet),
'gamma4': Condition(Span({'x': 0, 'y': [0, 1]}), nil_dirichlet),
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
}
def poisson_sol(self, x, y):