fix old codes

This commit is contained in:
Your Name
2022-07-11 10:58:15 +02:00
parent 088649e042
commit f526a26050
19 changed files with 385 additions and 457 deletions

View File

@@ -2,9 +2,9 @@ import argparse
import torch
from torch.nn import Softplus
from pina import PINN, Plotter
from pina import PINN, Plotter, LabelTensor
from pina.model import FeedForward
from problems.burgers import Burgers1D
from burger2 import Burgers1D
class myFeature(torch.nn.Module):
@@ -16,7 +16,7 @@ class myFeature(torch.nn.Module):
self.idx = idx
def forward(self, x):
return torch.sin(torch.pi * x[:, self.idx])
return LabelTensor(torch.sin(torch.pi * x.extract(['x'])), ['sin(x)'])
if __name__ == "__main__":
@@ -45,12 +45,14 @@ if __name__ == "__main__":
model,
lr=0.006,
error_norm='mse',
regularizer=0,
lr_accelerate=None)
regularizer=0)
if args.s:
pinn.span_pts(2000, 'latin', ['D'])
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 't0'])
pinn.span_pts(
{'n': 200, 'mode': 'random', 'variables': 't'},
{'n': 20, 'mode': 'random', 'variables': 'x'},
locations=['D'])
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
pinn.train(5000, 100)
pinn.save_state('pina.burger.{}.{}'.format(args.id_run, args.features))
else: