fix examples (#21)

This commit is contained in:
Nicola Demo
2022-07-21 13:41:59 +02:00
committed by GitHub
parent 62f203fcc3
commit e8c2f87460
14 changed files with 63 additions and 67 deletions

View File

@@ -1,15 +1,12 @@
import argparse import argparse
import torch import torch
from torch.nn import ReLU, Tanh, Softplus, PReLU from torch.nn import Softplus
from pina.problem import SpatialProblem, ParametricProblem from pina.problem import SpatialProblem
from pina.operators import nabla, grad, div from pina.operators import grad
from pina.model import FeedForward, DeepONet from pina.model import FeedForward
from pina import Condition, Span, LabelTensor, Plotter, PINN from pina import Condition, Span, Plotter, PINN
import matplotlib
matplotlib.use('Qt5Agg')
class FirstOrderODE(SpatialProblem): class FirstOrderODE(SpatialProblem):

View File

@@ -8,10 +8,9 @@ from pina.span import Span
class Burgers1D(TimeDependentProblem, SpatialProblem): class Burgers1D(TimeDependentProblem, SpatialProblem):
spatial_variables = ['x']
temporal_variable = ['t']
output_variables = ['u'] output_variables = ['u']
domain = Span({'x': [-1, 1], 't': [0, 1]}) spatial_domain = Span({'x': [-1, 1]})
temporal_domain = Span({'t': [0, 1]})
def burger_equation(input_, output_): def burger_equation(input_, output_):
du = grad(output_, input_) du = grad(output_, input_)

View File

@@ -16,11 +16,9 @@ class ParametricEllipticOptimalControl(SpatialProblem, ParametricProblem):
x_range = [xmin, xmax] x_range = [xmin, xmax]
y_range = [ymin, ymax] y_range = [ymin, ymax]
spatial_variables = ['x1', 'x2']
parameters = ['mu', 'alpha']
output_variables = ['u', 'p', 'y'] output_variables = ['u', 'p', 'y']
domain = Span({ spatial_domain = Span({'x1': x_range, 'x2': y_range})
'x1': x_range, 'x2': y_range, 'mu': mu_range, 'alpha': a_range}) parameter_domain = Span({'mu': mu_range, 'alpha': a_range})
def term1(input_, output_): def term1(input_, output_):

View File

@@ -7,10 +7,9 @@ from pina import Span, Condition
class ParametricPoisson(SpatialProblem, ParametricProblem): class ParametricPoisson(SpatialProblem, ParametricProblem):
spatial_variables = ['x', 'y']
parameters = ['mu1', 'mu2']
output_variables = ['u'] output_variables = ['u']
domain = Span({'x': [-1, 1], 'y': [-1, 1]}) spatial_domain = Span({'x': [-1, 1], 'y': [-1, 1]})
parameter_domain = Span({'mu1': [-1, 1], 'mu2': [-1, 1]})
def laplace_equation(input_, output_): def laplace_equation(input_, output_):
force_term = torch.exp( force_term = torch.exp(

View File

@@ -8,9 +8,8 @@ from pina import Condition, Span
class Poisson(SpatialProblem): class Poisson(SpatialProblem):
spatial_variables = ['x', 'y']
output_variables = ['u'] output_variables = ['u']
domain = Span({'x': [0, 1], 'y': [0, 1]}) spatial_domain = Span({'x': [0, 1], 'y': [0, 1]})
def laplace_equation(input_, output_): def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) * force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
@@ -30,7 +29,11 @@ class Poisson(SpatialProblem):
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation), 'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
} }
def poisson_sol(self, x, y): def poisson_sol(self, pts):
return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2) return -(
torch.sin(pts.extract(['x'])*torch.pi)*
torch.sin(pts.extract(['y'])*torch.pi)
)/(2*torch.pi**2)
#return -(np.sin(x*np.pi)*np.sin(y*np.pi))/(2*np.pi**2)
truth_solution = poisson_sol truth_solution = poisson_sol

View File

@@ -8,9 +8,8 @@ from pina import Condition, Span, LabelTensor
class Stokes(SpatialProblem): class Stokes(SpatialProblem):
spatial_variables = ['x', 'y']
output_variables = ['ux', 'uy', 'p'] output_variables = ['ux', 'uy', 'p']
domain = Span({'x': [-2, 2], 'y': [-1, 1]}) spatial_domain = Span({'x': [-2, 2], 'y': [-1, 1]})
def momentum(input_, output_): def momentum(input_, output_):
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']), nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),

View File

@@ -4,7 +4,7 @@ from torch.nn import Softplus
from pina import PINN, Plotter, LabelTensor from pina import PINN, Plotter, LabelTensor
from pina.model import FeedForward from pina.model import FeedForward
from burger2 import Burgers1D from problems.burgers import Burgers1D
class myFeature(torch.nn.Module): class myFeature(torch.nn.Module):
@@ -49,8 +49,8 @@ if __name__ == "__main__":
if args.s: if args.s:
pinn.span_pts( pinn.span_pts(
{'n': 200, 'mode': 'random', 'variables': 't'}, {'n': 200, 'mode': 'grid', 'variables': 't'},
{'n': 20, 'mode': 'random', 'variables': 'x'}, {'n': 20, 'mode': 'grid', 'variables': 'x'},
locations=['D']) locations=['D'])
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0']) pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
pinn.train(5000, 100) pinn.train(5000, 100)

View File

@@ -1,11 +1,13 @@
import argparse import argparse
import numpy as np import numpy as np
import torch import torch
from torch.nn import ReLU, Tanh, Softplus from torch.nn import Softplus
from pina import PINN, LabelTensor, Plotter
from pina.model import MultiFeedForward
from problems.parametric_elliptic_optimal_control_alpha_variable import (
ParametricEllipticOptimalControl)
from pina import PINN, LabelTensor
from parametric_elliptic_optimal_control_alpha_variable2 import ParametricEllipticOptimalControl
from pina.model import MultiFeedForward, FeedForward
class myFeature(torch.nn.Module): class myFeature(torch.nn.Module):
""" """
@@ -31,7 +33,6 @@ class CustomMultiDFF(MultiFeedForward):
return out.append(p) return out.append(p)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA") parser = argparse.ArgumentParser(description="Run PINA")
@@ -71,8 +72,12 @@ if __name__ == "__main__":
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5}, {'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4']) locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.train(10000, 20) pinn.train(1000, 20)
pinn.save_state('pina.ocp') pinn.save_state('pina.ocp')
else: else:
pinn.load_state('pina.ocp') pinn.load_state('pina.ocp')
plotter = Plotter()
plotter.plot(pinn, components='y', fixed_variables={'alpha': 0.01, 'mu': 1.0})
plotter.plot(pinn, components='u_param', fixed_variables={'alpha': 0.01, 'mu': 1.0})
plotter.plot(pinn, components='p', fixed_variables={'alpha': 0.01, 'mu': 1.0})

View File

@@ -2,8 +2,8 @@ import argparse
import torch import torch
from torch.nn import Softplus from torch.nn import Softplus
from pina import Plotter, LabelTensor, PINN from pina import Plotter, LabelTensor, PINN
from parametric_poisson2 import ParametricPoisson
from pina.model import FeedForward from pina.model import FeedForward
from problems.parametric_poisson import ParametricPoisson
class myFeature(torch.nn.Module): class myFeature(torch.nn.Module):
@@ -43,11 +43,7 @@ if __name__ == "__main__":
extra_features=feat extra_features=feat
) )
pinn = PINN( pinn = PINN(poisson_problem, model, lr=0.006, regularizer=1e-6)
poisson_problem,
model,
lr=0.006,
regularizer=1e-6)
if args.s: if args.s:
@@ -65,4 +61,5 @@ if __name__ == "__main__":
else: else:
pinn.load_state('pina.poisson_param') pinn.load_state('pina.poisson_param')
plotter = Plotter() plotter = Plotter()
plotter.plot(pinn, component='u', parametric=True, params_value=0) plotter.plot(pinn, fixed_variables={'mu1': 0, 'mu2': 1}, levels=21)
plotter.plot(pinn, fixed_variables={'mu1': 1, 'mu2': -1}, levels=21)

View File

@@ -7,7 +7,7 @@ from torch.nn import ReLU, Tanh, Softplus
from pina import PINN, LabelTensor, Plotter from pina import PINN, LabelTensor, Plotter
from pina.model import FeedForward from pina.model import FeedForward
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
from poisson2 import Poisson from problems.poisson import Poisson
class myFeature(torch.nn.Module): class myFeature(torch.nn.Module):
@@ -61,6 +61,4 @@ if __name__ == "__main__":
else: else:
pinn.load_state('pina.poisson') pinn.load_state('pina.poisson')
plotter = Plotter() plotter = Plotter()
plotter.plot(pinn, component='u') plotter.plot(pinn)

View File

@@ -36,9 +36,8 @@ if __name__ == "__main__":
if args.s: if args.s:
pinn.span_pts(200, mode_spatial='grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out']) pinn.span_pts(200, 'grid', locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
pinn.span_pts(2000, mode_spatial='random', locations=['D']) pinn.span_pts(2000, 'random', locations=['D'])
pinn.plot_pts()
pinn.train(10000, 100) pinn.train(10000, 100)
with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_: with open('stokes_history_{}.txt'.format(args.id_run), 'w') as file_:
for i, losses in enumerate(pinn.history): for i, losses in enumerate(pinn.history):
@@ -48,8 +47,8 @@ if __name__ == "__main__":
else: else:
pinn.load_state('pina.stokes') pinn.load_state('pina.stokes')
plotter = Plotter() plotter = Plotter()
plotter.plot(pinn, component='ux') plotter.plot(pinn, components='ux')
plotter.plot(pinn, component='uy') plotter.plot(pinn, components='uy')
plotter.plot(pinn, component='p') plotter.plot(pinn, components='p')

View File

@@ -50,22 +50,23 @@ class Plotter:
plt.legend() plt.legend()
plt.show() plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None): def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
""" """
""" """
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(pts, pred.detach()) ax.plot(pts, pred.detach(), **kwargs)
if truth_solution: if truth_solution:
truth_output = truth_solution(pts).float() truth_output = truth_solution(pts).float()
ax.plot(pts, truth_output.detach()) ax.plot(pts, truth_output.detach(), **kwargs)
plt.xlabel(pts.labels[0]) plt.xlabel(pts.labels[0])
plt.ylabel(pred.labels[0]) plt.ylabel(pred.labels[0])
plt.show() plt.show()
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None): def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
**kwargs):
""" """
""" """
@@ -73,27 +74,29 @@ class Plotter:
pred_output = pred.reshape(res, res) pred_output = pred.reshape(res, res)
if truth_solution: if truth_solution:
truth_output = truth_solution(*pts.T).float().reshape(res, res) truth_output = truth_solution(pts).float().reshape(res, res)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.detach()) cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax[0]) fig.colorbar(cb, ax=ax[0])
cb = getattr(ax[1], method)(*grids, truth_output.detach()) cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax[1]) fig.colorbar(cb, ax=ax[1])
cb = getattr(ax[2], method)(*grids, cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).detach()) (truth_output-pred_output).detach(),
**kwargs)
fig.colorbar(cb, ax=ax[2]) fig.colorbar(cb, ax=ax[2])
else: else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.detach()) cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
fig.colorbar(cb, ax=ax) fig.colorbar(cb, ax=ax)
def plot(self, pinn, components, fixed_variables={}, method='contourf', def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
res=256, filename=None): res=256, filename=None, **kwargs):
""" """
""" """
if components is None:
components = [pinn.problem.output_variables]
v = [ v = [
var for var in pinn.problem.input_variables var for var in pinn.problem.input_variables
if var not in fixed_variables.keys() if var not in fixed_variables.keys()
@@ -112,10 +115,11 @@ class Plotter:
truth_solution = getattr(pinn.problem, 'truth_solution', None) truth_solution = getattr(pinn.problem, 'truth_solution', None)
if len(v) == 1: if len(v) == 1:
self._1d_plot(pts, predicted_output, method, truth_solution) self._1d_plot(pts, predicted_output, method, truth_solution,
**kwargs)
elif len(v) == 2: elif len(v) == 2:
self._2d_plot(pts, predicted_output, v, res, method, self._2d_plot(pts, predicted_output, v, res, method,
truth_solution) truth_solution, **kwargs)
if filename: if filename:
plt.title('Output {} with parameter {}'.format(components, plt.title('Output {} with parameter {}'.format(components,

View File

@@ -39,8 +39,6 @@ class AbstractProblem(metaclass=ABCMeta):
else: else:
raise RuntimeError('different domains') raise RuntimeError('different domains')
@input_variables.setter @input_variables.setter
def input_variables(self, variables): def input_variables(self, variables):
raise RuntimeError raise RuntimeError

View File

@@ -10,5 +10,5 @@ class TimeDependentProblem(AbstractProblem):
pass pass
@property @property
def temporal_variables(self): def temporal_variable(self):
return self.temporal_domain.variables return self.temporal_domain.variables