fix examples (#21)
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from torch.nn import ReLU, Tanh, Softplus, PReLU
|
||||
from torch.nn import Softplus
|
||||
|
||||
from pina.problem import SpatialProblem, ParametricProblem
|
||||
from pina.operators import nabla, grad, div
|
||||
from pina.model import FeedForward, DeepONet
|
||||
from pina import Condition, Span, LabelTensor, Plotter, PINN
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('Qt5Agg')
|
||||
from pina.problem import SpatialProblem
|
||||
from pina.operators import grad
|
||||
from pina.model import FeedForward
|
||||
from pina import Condition, Span, Plotter, PINN
|
||||
|
||||
|
||||
class FirstOrderODE(SpatialProblem):
|
||||
|
||||
@@ -8,10 +8,9 @@ from pina.span import Span
|
||||
|
||||
class Burgers1D(TimeDependentProblem, SpatialProblem):
|
||||
|
||||
spatial_variables = ['x']
|
||||
temporal_variable = ['t']
|
||||
output_variables = ['u']
|
||||
domain = Span({'x': [-1, 1], 't': [0, 1]})
|
||||
spatial_domain = Span({'x': [-1, 1]})
|
||||
temporal_domain = Span({'t': [0, 1]})
|
||||
|
||||
def burger_equation(input_, output_):
|
||||
du = grad(output_, input_)
|
||||
|
||||
@@ -16,11 +16,9 @@ class ParametricEllipticOptimalControl(SpatialProblem, ParametricProblem):
|
||||
x_range = [xmin, xmax]
|
||||
y_range = [ymin, ymax]
|
||||
|
||||
spatial_variables = ['x1', 'x2']
|
||||
parameters = ['mu', 'alpha']
|
||||
output_variables = ['u', 'p', 'y']
|
||||
domain = Span({
|
||||
'x1': x_range, 'x2': y_range, 'mu': mu_range, 'alpha': a_range})
|
||||
spatial_domain = Span({'x1': x_range, 'x2': y_range})
|
||||
parameter_domain = Span({'mu': mu_range, 'alpha': a_range})
|
||||
|
||||
|
||||
def term1(input_, output_):
|
||||
|
||||
@@ -7,10 +7,9 @@ from pina import Span, Condition
|
||||
|
||||
class ParametricPoisson(SpatialProblem, ParametricProblem):
|
||||
|
||||
spatial_variables = ['x', 'y']
|
||||
parameters = ['mu1', 'mu2']
|
||||
output_variables = ['u']
|
||||
domain = Span({'x': [-1, 1], 'y': [-1, 1]})
|
||||
spatial_domain = Span({'x': [-1, 1], 'y': [-1, 1]})
|
||||
parameter_domain = Span({'mu1': [-1, 1], 'mu2': [-1, 1]})
|
||||
|
||||
def laplace_equation(input_, output_):
|
||||
force_term = torch.exp(
|
||||
|
||||
@@ -8,9 +8,8 @@ from pina import Condition, Span
|
||||
|
||||
class Poisson(SpatialProblem):
|
||||
|
||||
spatial_variables = ['x', 'y']
|
||||
output_variables = ['u']
|
||||
domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||
spatial_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||
|
||||
def laplace_equation(input_, output_):
|
||||
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
||||
@@ -30,7 +29,11 @@ class Poisson(SpatialProblem):
|
||||
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
|
||||
}
|
||||
|
||||
def poisson_sol(self, x, y):
|
||||
return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2)
|
||||
def poisson_sol(self, pts):
|
||||
return -(
|
||||
torch.sin(pts.extract(['x'])*torch.pi)*
|
||||
torch.sin(pts.extract(['y'])*torch.pi)
|
||||
)/(2*torch.pi**2)
|
||||
#return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2)
|
||||
|
||||
truth_solution = poisson_sol
|
||||
|
||||
@@ -8,9 +8,8 @@ from pina import Condition, Span, LabelTensor
|
||||
|
||||
class Stokes(SpatialProblem):
|
||||
|
||||
spatial_variables = ['x', 'y']
|
||||
output_variables = ['ux', 'uy', 'p']
|
||||
domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
||||
spatial_domain = Span({'x': [-2, 2], 'y': [-1, 1]})
|
||||
|
||||
def momentum(input_, output_):
|
||||
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
|
||||
|
||||
@@ -4,7 +4,7 @@ from torch.nn import Softplus
|
||||
|
||||
from pina import PINN, Plotter, LabelTensor
|
||||
from pina.model import FeedForward
|
||||
from burger2 import Burgers1D
|
||||
from problems.burgers import Burgers1D
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -49,8 +49,8 @@ if __name__ == "__main__":
|
||||
|
||||
if args.s:
|
||||
pinn.span_pts(
|
||||
{'n': 200, 'mode': 'random', 'variables': 't'},
|
||||
{'n': 20, 'mode': 'random', 'variables': 'x'},
|
||||
{'n': 200, 'mode': 'grid', 'variables': 't'},
|
||||
{'n': 20, 'mode': 'grid', 'variables': 'x'},
|
||||
locations=['D'])
|
||||
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
|
||||
pinn.train(5000, 100)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import ReLU, Tanh, Softplus
|
||||
from torch.nn import Softplus
|
||||
|
||||
from pina import PINN, LabelTensor, Plotter
|
||||
from pina.model import MultiFeedForward
|
||||
from problems.parametric_elliptic_optimal_control_alpha_variable import (
|
||||
ParametricEllipticOptimalControl)
|
||||
|
||||
from pina import PINN, LabelTensor
|
||||
from parametric_elliptic_optimal_control_alpha_variable2 import ParametricEllipticOptimalControl
|
||||
from pina.model import MultiFeedForward, FeedForward
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
"""
|
||||
@@ -31,7 +33,6 @@ class CustomMultiDFF(MultiFeedForward):
|
||||
return out.append(p)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run PINA")
|
||||
@@ -71,8 +72,12 @@ if __name__ == "__main__":
|
||||
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
|
||||
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
|
||||
pinn.train(10000, 20)
|
||||
pinn.train(1000, 20)
|
||||
pinn.save_state('pina.ocp')
|
||||
|
||||
else:
|
||||
pinn.load_state('pina.ocp')
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn, components='y', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
plotter.plot(pinn, components='u_param', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
plotter.plot(pinn, components='p', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
|
||||
@@ -2,8 +2,8 @@ import argparse
|
||||
import torch
|
||||
from torch.nn import Softplus
|
||||
from pina import Plotter, LabelTensor, PINN
|
||||
from parametric_poisson2 import ParametricPoisson
|
||||
from pina.model import FeedForward
|
||||
from problems.parametric_poisson import ParametricPoisson
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -43,11 +43,7 @@ if __name__ == "__main__":
|
||||
extra_features=feat
|
||||
)
|
||||
|
||||
pinn = PINN(
|
||||
poisson_problem,
|
||||
model,
|
||||
lr=0.006,
|
||||
regularizer=1e-6)
|
||||
pinn = PINN(poisson_problem, model, lr=0.006, regularizer=1e-6)
|
||||
|
||||
if args.s:
|
||||
|
||||
@@ -65,4 +61,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
pinn.load_state('pina.poisson_param')
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn, component='u', parametric=True, params_value=0)
|
||||
plotter.plot(pinn, fixed_variables={'mu1': 0, 'mu2': 1}, levels=21)
|
||||
plotter.plot(pinn, fixed_variables={'mu1': 1, 'mu2': -1}, levels=21)
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch.nn import ReLU, Tanh, Softplus
|
||||
from pina import PINN, LabelTensor, Plotter
|
||||
from pina.model import FeedForward
|
||||
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
||||
from poisson2 import Poisson
|
||||
from problems.poisson import Poisson
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -61,6 +61,4 @@ if __name__ == "__main__":
|
||||
else:
|
||||
pinn.load_state('pina.poisson')
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn, component='u')
|
||||
|
||||
|
||||
plotter.plot(pinn)
|
||||
|
||||
@@ -36,9 +36,8 @@ if __name__ == "__main__":
|
||||
|
||||
if args.s:
|
||||
|
||||
pinn.span_pts(200, mode_spatial='grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||
pinn.span_pts(2000, mode_spatial='random', locations=['D'])
|
||||
pinn.plot_pts()
|
||||
pinn.span_pts(200, 'grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||
pinn.span_pts(2000, 'random', locations=['D'])
|
||||
pinn.train(10000, 100)
|
||||
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
|
||||
for i, losses in enumerate(pinn.history):
|
||||
@@ -48,8 +47,8 @@ if __name__ == "__main__":
|
||||
else:
|
||||
pinn.load_state('pina.stokes')
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn, component='ux')
|
||||
plotter.plot(pinn, component='uy')
|
||||
plotter.plot(pinn, component='p')
|
||||
plotter.plot(pinn, components='ux')
|
||||
plotter.plot(pinn, components='uy')
|
||||
plotter.plot(pinn, components='p')
|
||||
|
||||
|
||||
|
||||
@@ -50,22 +50,23 @@ class Plotter:
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None):
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
|
||||
"""
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
ax.plot(pts, pred.detach())
|
||||
ax.plot(pts, pred.detach(), **kwargs)
|
||||
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float()
|
||||
ax.plot(pts, truth_output.detach())
|
||||
ax.plot(pts, truth_output.detach(), **kwargs)
|
||||
|
||||
plt.xlabel(pts.labels[0])
|
||||
plt.ylabel(pred.labels[0])
|
||||
plt.show()
|
||||
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None):
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||
**kwargs):
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -73,27 +74,29 @@ class Plotter:
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(*pts.T).float().reshape(res, res)
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.detach())
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.detach())
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
cb = getattr(ax[2], method)(*grids,
|
||||
(truth_output-pred_output).detach())
|
||||
(truth_output-pred_output).detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output.detach())
|
||||
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
|
||||
def plot(self, pinn, components, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None):
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
"""
|
||||
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
v = [
|
||||
var for var in pinn.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
@@ -112,10 +115,11 @@ class Plotter:
|
||||
|
||||
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution)
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
elif len(v) == 2:
|
||||
self._2d_plot(pts, predicted_output, v, res, method,
|
||||
truth_solution)
|
||||
truth_solution, **kwargs)
|
||||
|
||||
if filename:
|
||||
plt.title('Output {} with parameter {}'.format(components,
|
||||
|
||||
@@ -39,8 +39,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
else:
|
||||
raise RuntimeError('different domains')
|
||||
|
||||
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
|
||||
@@ -10,5 +10,5 @@ class TimeDependentProblem(AbstractProblem):
|
||||
pass
|
||||
|
||||
@property
|
||||
def temporal_variables(self):
|
||||
def temporal_variable(self):
|
||||
return self.temporal_domain.variables
|
||||
|
||||
Reference in New Issue
Block a user