version 0.0.1

This commit is contained in:
Your Name
2022-02-11 16:44:37 +01:00
parent fa8ffd5042
commit 1483746b45
29 changed files with 416 additions and 559 deletions

View File

@@ -1,15 +1,10 @@
import sys
import numpy as np
import torch
import argparse
from pina.pinn import PINN
from pina.ppinn import ParametricPINN as pPINN
from pina.label_tensor import LabelTensor
from torch.nn import ReLU, Tanh, Softplus
from problems.burgers import Burgers1D
from pina.deep_feed_forward import DeepFeedForward
import torch
from torch.nn import Softplus
from pina import Plotter
from pina import PINN, Plotter
from pina.model import FeedForward
from problems.burgers import Burgers1D
class myFeature(torch.nn.Module):
@@ -23,6 +18,7 @@ class myFeature(torch.nn.Module):
def forward(self, x):
return torch.sin(torch.pi * x[:, self.idx])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA")
@@ -36,11 +32,8 @@ if __name__ == "__main__":
feat = [myFeature(0)] if args.features else []
burgers_problem = Burgers1D()
model = DeepFeedForward(
model = FeedForward(
layers=[30, 20, 10, 5],
#layers=[8, 8, 8],
#layers=[16, 8, 4, 4],
#layers=[20, 4, 4, 4],
output_variables=burgers_problem.output_variables,
input_variables=burgers_problem.input_variables,
func=Softplus,
@@ -57,7 +50,7 @@ if __name__ == "__main__":
if args.s:
pinn.span_pts(2000, 'latin', ['D'])
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 'initia'])
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 't0'])
pinn.train(5000, 100)
pinn.save_state('pina.burger.{}.{}'.format(args.id_run, args.features))
else: