45
examples/problems/stokes.py
Normal file
45
examples/problems/stokes.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pina.problem import SpatialProblem
|
||||
from pina.operators import nabla, grad, div
|
||||
from pina import Condition, Span, LabelTensor
|
||||
|
||||
|
||||
class Stokes(SpatialProblem):
|
||||
|
||||
spatial_variables = ['x', 'y']
|
||||
output_variables = ['ux', 'uy', 'p']
|
||||
domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
||||
|
||||
def momentum(input_, output_):
|
||||
#print(nabla(output_['ux', 'uy'], input_))
|
||||
#print(grad(output_['p'], input_))
|
||||
nabla_ = LabelTensor.hstack([
|
||||
LabelTensor(nabla(output_['ux'], input_), ['x']),
|
||||
LabelTensor(nabla(output_['uy'], input_), ['y'])])
|
||||
#return LabelTensor(nabla_.tensor + grad(output_['p'], input_).tensor, ['x', 'y'])
|
||||
return nabla_.tensor + grad(output_['p'], input_).tensor
|
||||
|
||||
def continuity(input_, output_):
|
||||
return div(output_['ux', 'uy'], input_)
|
||||
|
||||
def inlet(input_, output_):
|
||||
value = 2.0
|
||||
return output_['ux'] - value
|
||||
|
||||
def outlet(input_, output_):
|
||||
value = 0.0
|
||||
return output_['p'] - value
|
||||
|
||||
def wall(input_, output_):
|
||||
value = 0.0
|
||||
return output_['ux', 'uy'].tensor - value
|
||||
|
||||
conditions = {
|
||||
'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall),
|
||||
'gamma_bot': Condition(Span({'x': [-2, 2], 'y': -1}), wall),
|
||||
'gamma_out': Condition(Span({'x': 2, 'y': [-1, 1]}), outlet),
|
||||
'gamma_in': Condition(Span({'x': -2, 'y': [-1, 1]}), inlet),
|
||||
'D': Condition(Span({'x': [-2, 2], 'y': [-1, 1]}), [momentum, continuity]),
|
||||
}
|
||||
54
examples/run_stokes.py
Normal file
54
examples/run_stokes.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import ReLU, Tanh, Softplus
|
||||
|
||||
from pina import PINN, LabelTensor, Plotter
|
||||
from pina.model import FeedForward
|
||||
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
||||
from problems.stokes import Stokes
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run PINA")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("-s", "-save", action="store_true")
|
||||
group.add_argument("-l", "-load", action="store_true")
|
||||
parser.add_argument("id_run", help="number of run", type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
stokes_problem = Stokes()
|
||||
model = FeedForward(
|
||||
layers=[40, 20, 20, 10],
|
||||
output_variables=stokes_problem.output_variables,
|
||||
input_variables=stokes_problem.input_variables,
|
||||
func=Softplus,
|
||||
)
|
||||
|
||||
pinn = PINN(
|
||||
stokes_problem,
|
||||
model,
|
||||
lr=0.006,
|
||||
error_norm='mse',
|
||||
regularizer=1e-8,
|
||||
lr_accelerate=None)
|
||||
|
||||
if args.s:
|
||||
|
||||
#pinn.span_pts(200, 'grid', ['gamma_out'])
|
||||
pinn.span_pts(200, 'grid', ['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||
pinn.span_pts(2000, 'random', ['D'])
|
||||
#plotter = Plotter()
|
||||
#plotter.plot_samples(pinn)
|
||||
pinn.train(10000, 100)
|
||||
pinn.save_state('pina.stokes')
|
||||
|
||||
else:
|
||||
pinn.load_state('pina.stokes')
|
||||
plotter = Plotter()
|
||||
plotter.plot_samples(pinn)
|
||||
plotter.plot(pinn)
|
||||
|
||||
|
||||
@@ -15,6 +15,9 @@ class Condition:
|
||||
elif isinstance(args[0], Location) and callable(args[1]):
|
||||
self.location = args[0]
|
||||
self.function = args[1]
|
||||
elif isinstance(args[0], Location) and isinstance(args[1], list):
|
||||
self.location = args[0]
|
||||
self.function = args[1]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ class LabelTensor():
|
||||
self.tensor = x
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, (tuple, list)):
|
||||
indeces = [self.labels.index(k) for k in key]
|
||||
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
|
||||
if key in self.labels:
|
||||
return self.tensor[:, self.labels.index(key)]
|
||||
else:
|
||||
|
||||
@@ -52,6 +52,8 @@ class FeedForward(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
"""
|
||||
"""
|
||||
|
||||
x = x[self.input_variables]
|
||||
nf = len(self.extra_features)
|
||||
if nf == 0:
|
||||
return LabelTensor(self.model(x.tensor), self.output_variables)
|
||||
|
||||
15
pina/pinn.py
15
pina/pinn.py
@@ -65,8 +65,8 @@ class PINN(object):
|
||||
self.model = model
|
||||
self.model.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
self.input_pts = {}
|
||||
self.truth_values = {}
|
||||
self.input_pts = {}
|
||||
|
||||
|
||||
self.trained_epoch = 0
|
||||
@@ -171,13 +171,15 @@ class PINN(object):
|
||||
except:
|
||||
pts = condition.input_points
|
||||
|
||||
print(location, pts)
|
||||
|
||||
self.input_pts[location] = pts
|
||||
print(pts.tensor.shape)
|
||||
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
|
||||
self.input_pts[location].tensor.requires_grad_(True)
|
||||
self.input_pts[location].tensor.retain_grad()
|
||||
|
||||
|
||||
|
||||
def plot_pts(self, locations='all'):
|
||||
import matplotlib
|
||||
matplotlib.use('GTK3Agg')
|
||||
@@ -209,8 +211,13 @@ class PINN(object):
|
||||
|
||||
predicted = self.model(pts)
|
||||
|
||||
residuals = condition.function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
if isinstance(condition.function, list):
|
||||
for function in condition.function:
|
||||
residuals = function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
else:
|
||||
residuals = condition.function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
sum(losses).backward()
|
||||
|
||||
@@ -84,11 +84,13 @@ class Plotter:
|
||||
"""
|
||||
res = 256
|
||||
pts = obj.problem.domain.sample(res, 'grid')
|
||||
print(pts)
|
||||
grids_container = [
|
||||
pts[:, 0].reshape(res, res),
|
||||
pts[:, 1].reshape(res, res),
|
||||
pts.tensor[:, 0].reshape(res, res),
|
||||
pts.tensor[:, 1].reshape(res, res),
|
||||
]
|
||||
predicted_output = obj.model(pts)
|
||||
predicted_output = predicted_output['p']
|
||||
|
||||
if hasattr(obj.problem, 'truth_solution'):
|
||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||
@@ -102,10 +104,56 @@ class Plotter:
|
||||
fig.colorbar(cb, ax=axes[2])
|
||||
else:
|
||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||
fig.colorbar(cb, ax=axes)
|
||||
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot(self, obj, method='contourf', filename=None):
|
||||
"""
|
||||
"""
|
||||
res = 256
|
||||
pts = obj.problem.domain.sample(res, 'grid')
|
||||
print(pts)
|
||||
grids_container = [
|
||||
pts.tensor[:, 0].reshape(res, res),
|
||||
pts.tensor[:, 1].reshape(res, res),
|
||||
]
|
||||
predicted_output = obj.model(pts)
|
||||
predicted_output = predicted_output['ux']
|
||||
|
||||
if hasattr(obj.problem, 'truth_solution'):
|
||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||
fig.colorbar(cb, ax=axes[0])
|
||||
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
||||
fig.colorbar(cb, ax=axes[1])
|
||||
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
|
||||
fig.colorbar(cb, ax=axes[2])
|
||||
else:
|
||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||
fig.colorbar(cb, ax=axes)
|
||||
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
def plot_samples(self, obj):
|
||||
|
||||
for location in obj.input_pts:
|
||||
plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location)
|
||||
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
@@ -37,6 +37,7 @@ class Span(Location):
|
||||
for _ in range(bounds.shape[0])])
|
||||
grids = np.meshgrid(*pts)
|
||||
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
||||
print(pts)
|
||||
elif mode == 'lh' or mode == 'latin':
|
||||
from scipy.stats import qmc
|
||||
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
||||
|
||||
Reference in New Issue
Block a user