fix examples (#21)
This commit is contained in:
@@ -1,15 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.nn import ReLU, Tanh, Softplus, PReLU
|
from torch.nn import Softplus
|
||||||
|
|
||||||
from pina.problem import SpatialProblem, ParametricProblem
|
from pina.problem import SpatialProblem
|
||||||
from pina.operators import nabla, grad, div
|
from pina.operators import grad
|
||||||
from pina.model import FeedForward, DeepONet
|
from pina.model import FeedForward
|
||||||
from pina import Condition, Span, LabelTensor, Plotter, PINN
|
from pina import Condition, Span, Plotter, PINN
|
||||||
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('Qt5Agg')
|
|
||||||
|
|
||||||
|
|
||||||
class FirstOrderODE(SpatialProblem):
|
class FirstOrderODE(SpatialProblem):
|
||||||
|
|||||||
@@ -8,10 +8,9 @@ from pina.span import Span
|
|||||||
|
|
||||||
class Burgers1D(TimeDependentProblem, SpatialProblem):
|
class Burgers1D(TimeDependentProblem, SpatialProblem):
|
||||||
|
|
||||||
spatial_variables = ['x']
|
|
||||||
temporal_variable = ['t']
|
|
||||||
output_variables = ['u']
|
output_variables = ['u']
|
||||||
domain = Span({'x': [-1, 1], 't': [0, 1]})
|
spatial_domain = Span({'x': [-1, 1]})
|
||||||
|
temporal_domain = Span({'t': [0, 1]})
|
||||||
|
|
||||||
def burger_equation(input_, output_):
|
def burger_equation(input_, output_):
|
||||||
du = grad(output_, input_)
|
du = grad(output_, input_)
|
||||||
|
|||||||
@@ -16,11 +16,9 @@ class ParametricEllipticOptimalControl(SpatialProblem, ParametricProblem):
|
|||||||
x_range = [xmin, xmax]
|
x_range = [xmin, xmax]
|
||||||
y_range = [ymin, ymax]
|
y_range = [ymin, ymax]
|
||||||
|
|
||||||
spatial_variables = ['x1', 'x2']
|
|
||||||
parameters = ['mu', 'alpha']
|
|
||||||
output_variables = ['u', 'p', 'y']
|
output_variables = ['u', 'p', 'y']
|
||||||
domain = Span({
|
spatial_domain = Span({'x1': x_range, 'x2': y_range})
|
||||||
'x1': x_range, 'x2': y_range, 'mu': mu_range, 'alpha': a_range})
|
parameter_domain = Span({'mu': mu_range, 'alpha': a_range})
|
||||||
|
|
||||||
|
|
||||||
def term1(input_, output_):
|
def term1(input_, output_):
|
||||||
|
|||||||
@@ -7,10 +7,9 @@ from pina import Span, Condition
|
|||||||
|
|
||||||
class ParametricPoisson(SpatialProblem, ParametricProblem):
|
class ParametricPoisson(SpatialProblem, ParametricProblem):
|
||||||
|
|
||||||
spatial_variables = ['x', 'y']
|
|
||||||
parameters = ['mu1', 'mu2']
|
|
||||||
output_variables = ['u']
|
output_variables = ['u']
|
||||||
domain = Span({'x': [-1, 1], 'y': [-1, 1]})
|
spatial_domain = Span({'x': [-1, 1], 'y': [-1, 1]})
|
||||||
|
parameter_domain = Span({'mu1': [-1, 1], 'mu2': [-1, 1]})
|
||||||
|
|
||||||
def laplace_equation(input_, output_):
|
def laplace_equation(input_, output_):
|
||||||
force_term = torch.exp(
|
force_term = torch.exp(
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ from pina import Condition, Span
|
|||||||
|
|
||||||
class Poisson(SpatialProblem):
|
class Poisson(SpatialProblem):
|
||||||
|
|
||||||
spatial_variables = ['x', 'y']
|
|
||||||
output_variables = ['u']
|
output_variables = ['u']
|
||||||
domain = Span({'x': [0, 1], 'y': [0, 1]})
|
spatial_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||||
|
|
||||||
def laplace_equation(input_, output_):
|
def laplace_equation(input_, output_):
|
||||||
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
||||||
@@ -30,7 +29,11 @@ class Poisson(SpatialProblem):
|
|||||||
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
|
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
|
||||||
}
|
}
|
||||||
|
|
||||||
def poisson_sol(self, x, y):
|
def poisson_sol(self, pts):
|
||||||
return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2)
|
return -(
|
||||||
|
torch.sin(pts.extract(['x'])*torch.pi)*
|
||||||
|
torch.sin(pts.extract(['y'])*torch.pi)
|
||||||
|
)/(2*torch.pi**2)
|
||||||
|
#return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2)
|
||||||
|
|
||||||
truth_solution = poisson_sol
|
truth_solution = poisson_sol
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ from pina import Condition, Span, LabelTensor
|
|||||||
|
|
||||||
class Stokes(SpatialProblem):
|
class Stokes(SpatialProblem):
|
||||||
|
|
||||||
spatial_variables = ['x', 'y']
|
|
||||||
output_variables = ['ux', 'uy', 'p']
|
output_variables = ['ux', 'uy', 'p']
|
||||||
domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
spatial_domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
||||||
|
|
||||||
def momentum(input_, output_):
|
def momentum(input_, output_):
|
||||||
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
|
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from torch.nn import Softplus
|
|||||||
|
|
||||||
from pina import PINN, Plotter, LabelTensor
|
from pina import PINN, Plotter, LabelTensor
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from burger2 import Burgers1D
|
from problems.burgers import Burgers1D
|
||||||
|
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
@@ -49,8 +49,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.s:
|
if args.s:
|
||||||
pinn.span_pts(
|
pinn.span_pts(
|
||||||
{'n': 200, 'mode': 'random', 'variables': 't'},
|
{'n': 200, 'mode': 'grid', 'variables': 't'},
|
||||||
{'n': 20, 'mode': 'random', 'variables': 'x'},
|
{'n': 20, 'mode': 'grid', 'variables': 'x'},
|
||||||
locations=['D'])
|
locations=['D'])
|
||||||
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
|
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
|
||||||
pinn.train(5000, 100)
|
pinn.train(5000, 100)
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import ReLU, Tanh, Softplus
|
from torch.nn import Softplus
|
||||||
|
|
||||||
|
from pina import PINN, LabelTensor, Plotter
|
||||||
|
from pina.model import MultiFeedForward
|
||||||
|
from problems.parametric_elliptic_optimal_control_alpha_variable import (
|
||||||
|
ParametricEllipticOptimalControl)
|
||||||
|
|
||||||
from pina import PINN, LabelTensor
|
|
||||||
from parametric_elliptic_optimal_control_alpha_variable2 import ParametricEllipticOptimalControl
|
|
||||||
from pina.model import MultiFeedForward, FeedForward
|
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -31,7 +33,6 @@ class CustomMultiDFF(MultiFeedForward):
|
|||||||
return out.append(p)
|
return out.append(p)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Run PINA")
|
parser = argparse.ArgumentParser(description="Run PINA")
|
||||||
@@ -71,8 +72,12 @@ if __name__ == "__main__":
|
|||||||
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
|
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
|
||||||
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||||
|
|
||||||
pinn.train(10000, 20)
|
pinn.train(1000, 20)
|
||||||
pinn.save_state('pina.ocp')
|
pinn.save_state('pina.ocp')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pinn.load_state('pina.ocp')
|
pinn.load_state('pina.ocp')
|
||||||
|
plotter = Plotter()
|
||||||
|
plotter.plot(pinn, components='y', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||||
|
plotter.plot(pinn, components='u_param', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||||
|
plotter.plot(pinn, components='p', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import argparse
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Softplus
|
from torch.nn import Softplus
|
||||||
from pina import Plotter, LabelTensor, PINN
|
from pina import Plotter, LabelTensor, PINN
|
||||||
from parametric_poisson2 import ParametricPoisson
|
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
|
from problems.parametric_poisson import ParametricPoisson
|
||||||
|
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
@@ -43,11 +43,7 @@ if __name__ == "__main__":
|
|||||||
extra_features=feat
|
extra_features=feat
|
||||||
)
|
)
|
||||||
|
|
||||||
pinn = PINN(
|
pinn = PINN(poisson_problem, model, lr=0.006, regularizer=1e-6)
|
||||||
poisson_problem,
|
|
||||||
model,
|
|
||||||
lr=0.006,
|
|
||||||
regularizer=1e-6)
|
|
||||||
|
|
||||||
if args.s:
|
if args.s:
|
||||||
|
|
||||||
@@ -65,4 +61,5 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
pinn.load_state('pina.poisson_param')
|
pinn.load_state('pina.poisson_param')
|
||||||
plotter = Plotter()
|
plotter = Plotter()
|
||||||
plotter.plot(pinn, component='u', parametric=True, params_value=0)
|
plotter.plot(pinn, fixed_variables={'mu1': 0, 'mu2': 1}, levels=21)
|
||||||
|
plotter.plot(pinn, fixed_variables={'mu1': 1, 'mu2': -1}, levels=21)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from torch.nn import ReLU, Tanh, Softplus
|
|||||||
from pina import PINN, LabelTensor, Plotter
|
from pina import PINN, LabelTensor, Plotter
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
||||||
from poisson2 import Poisson
|
from problems.poisson import Poisson
|
||||||
|
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
@@ -61,6 +61,4 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
pinn.load_state('pina.poisson')
|
pinn.load_state('pina.poisson')
|
||||||
plotter = Plotter()
|
plotter = Plotter()
|
||||||
plotter.plot(pinn, component='u')
|
plotter.plot(pinn)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,9 +36,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.s:
|
if args.s:
|
||||||
|
|
||||||
pinn.span_pts(200, mode_spatial='grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
pinn.span_pts(200, 'grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||||
pinn.span_pts(2000, mode_spatial='random', locations=['D'])
|
pinn.span_pts(2000, 'random', locations=['D'])
|
||||||
pinn.plot_pts()
|
|
||||||
pinn.train(10000, 100)
|
pinn.train(10000, 100)
|
||||||
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
|
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
|
||||||
for i, losses in enumerate(pinn.history):
|
for i, losses in enumerate(pinn.history):
|
||||||
@@ -48,8 +47,8 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
pinn.load_state('pina.stokes')
|
pinn.load_state('pina.stokes')
|
||||||
plotter = Plotter()
|
plotter = Plotter()
|
||||||
plotter.plot(pinn, component='ux')
|
plotter.plot(pinn, components='ux')
|
||||||
plotter.plot(pinn, component='uy')
|
plotter.plot(pinn, components='uy')
|
||||||
plotter.plot(pinn, component='p')
|
plotter.plot(pinn, components='p')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -50,22 +50,23 @@ class Plotter:
|
|||||||
plt.legend()
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def _1d_plot(self, pts, pred, method, truth_solution=None):
|
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||||
|
|
||||||
ax.plot(pts, pred.detach())
|
ax.plot(pts, pred.detach(), **kwargs)
|
||||||
|
|
||||||
if truth_solution:
|
if truth_solution:
|
||||||
truth_output = truth_solution(pts).float()
|
truth_output = truth_solution(pts).float()
|
||||||
ax.plot(pts, truth_output.detach())
|
ax.plot(pts, truth_output.detach(), **kwargs)
|
||||||
|
|
||||||
plt.xlabel(pts.labels[0])
|
plt.xlabel(pts.labels[0])
|
||||||
plt.ylabel(pred.labels[0])
|
plt.ylabel(pred.labels[0])
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
|
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||||
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -73,27 +74,29 @@ class Plotter:
|
|||||||
|
|
||||||
pred_output = pred.reshape(res, res)
|
pred_output = pred.reshape(res, res)
|
||||||
if truth_solution:
|
if truth_solution:
|
||||||
truth_output = truth_solution(*pts.T).float().reshape(res, res)
|
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||||
|
|
||||||
cb = getattr(ax[0], method)(*grids, pred_output.detach())
|
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax[0])
|
fig.colorbar(cb, ax=ax[0])
|
||||||
cb = getattr(ax[1], method)(*grids, truth_output.detach())
|
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax[1])
|
fig.colorbar(cb, ax=ax[1])
|
||||||
cb = getattr(ax[2], method)(*grids,
|
cb = getattr(ax[2], method)(*grids,
|
||||||
(truth_output-pred_output).detach())
|
(truth_output-pred_output).detach(),
|
||||||
|
**kwargs)
|
||||||
fig.colorbar(cb, ax=ax[2])
|
fig.colorbar(cb, ax=ax[2])
|
||||||
else:
|
else:
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(ax, method)(*grids, pred_output.detach())
|
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
|
||||||
fig.colorbar(cb, ax=ax)
|
fig.colorbar(cb, ax=ax)
|
||||||
|
|
||||||
|
|
||||||
def plot(self, pinn, components, fixed_variables={}, method='contourf',
|
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||||
res=256, filename=None):
|
res=256, filename=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
if components is None:
|
||||||
|
components = [pinn.problem.output_variables]
|
||||||
v = [
|
v = [
|
||||||
var for var in pinn.problem.input_variables
|
var for var in pinn.problem.input_variables
|
||||||
if var not in fixed_variables.keys()
|
if var not in fixed_variables.keys()
|
||||||
@@ -112,10 +115,11 @@ class Plotter:
|
|||||||
|
|
||||||
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
self._1d_plot(pts, predicted_output, method, truth_solution)
|
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||||
|
**kwargs)
|
||||||
elif len(v) == 2:
|
elif len(v) == 2:
|
||||||
self._2d_plot(pts, predicted_output, v, res, method,
|
self._2d_plot(pts, predicted_output, v, res, method,
|
||||||
truth_solution)
|
truth_solution, **kwargs)
|
||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
plt.title('Output {} with parameter {}'.format(components,
|
plt.title('Output {} with parameter {}'.format(components,
|
||||||
|
|||||||
@@ -39,8 +39,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError('different domains')
|
raise RuntimeError('different domains')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@input_variables.setter
|
@input_variables.setter
|
||||||
def input_variables(self, variables):
|
def input_variables(self, variables):
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|||||||
@@ -10,5 +10,5 @@ class TimeDependentProblem(AbstractProblem):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def temporal_variables(self):
|
def temporal_variable(self):
|
||||||
return self.temporal_domain.variables
|
return self.temporal_domain.variables
|
||||||
|
|||||||
Reference in New Issue
Block a user