fix old codes

This commit is contained in:
Your Name
2022-07-11 10:58:15 +02:00
parent 088649e042
commit f526a26050
19 changed files with 385 additions and 457 deletions

View File

@@ -1,9 +1,8 @@
import argparse
import torch
from torch.nn import Softplus
from pina import Plotter
from pina import PINN as pPINN
from problems.parametric_poisson import ParametricPoisson
from pina import Plotter, LabelTensor, PINN
from parametric_poisson2 import ParametricPoisson
from pina.model import FeedForward
@@ -14,7 +13,13 @@ class myFeature(torch.nn.Module):
super(myFeature, self).__init__()
def forward(self, x):
return torch.exp(- 2*(x.extract(['x']) - x.extract(['mu1']))**2 - 2*(x.extract(['y']) - x.extract(['mu2']))**2)
t = (
torch.exp(
- 2*(x.extract(['x']) - x.extract(['mu1']))**2
- 2*(x.extract(['y']) - x.extract(['mu2']))**2
)
)
return LabelTensor(t, ['k0'])
if __name__ == "__main__":
@@ -38,21 +43,23 @@ if __name__ == "__main__":
extra_features=feat
)
pinn = pPINN(
pinn = PINN(
poisson_problem,
model,
lr=0.0006,
lr=0.006,
regularizer=1e-6)
if args.s:
pinn.span_pts(500, n_params=10, mode_spatial='random', locations=['D'])
pinn.span_pts(200, n_params=10, mode_spatial='random', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.plot_pts()
pinn.span_pts(
{'variables': ['x', 'y'], 'mode': 'random', 'n': 100},
{'variables': ['mu1', 'mu2'], 'mode': 'grid', 'n': 5},
locations=['D'])
pinn.span_pts(
{'variables': ['x', 'y'], 'mode': 'grid', 'n': 20},
{'variables': ['mu1', 'mu2'], 'mode': 'grid', 'n': 5},
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.train(10000, 100)
with open('param_poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
for i, losses in enumerate(pinn.history):
file_.write('{} {}\n'.format(i, sum(losses)))
pinn.save_state('pina.poisson_param')
else: