version 0.0.1
This commit is contained in:
@@ -1,15 +1,10 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import argparse
|
||||
from pina.pinn import PINN
|
||||
from pina.ppinn import ParametricPINN as pPINN
|
||||
from pina.label_tensor import LabelTensor
|
||||
from torch.nn import ReLU, Tanh, Softplus
|
||||
from problems.burgers import Burgers1D
|
||||
from pina.deep_feed_forward import DeepFeedForward
|
||||
import torch
|
||||
from torch.nn import Softplus
|
||||
|
||||
from pina import Plotter
|
||||
from pina import PINN, Plotter
|
||||
from pina.model import FeedForward
|
||||
from problems.burgers import Burgers1D
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -23,6 +18,7 @@ class myFeature(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sin(torch.pi * x[:, self.idx])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run PINA")
|
||||
@@ -36,11 +32,8 @@ if __name__ == "__main__":
|
||||
feat = [myFeature(0)] if args.features else []
|
||||
|
||||
burgers_problem = Burgers1D()
|
||||
model = DeepFeedForward(
|
||||
model = FeedForward(
|
||||
layers=[30, 20, 10, 5],
|
||||
#layers=[8, 8, 8],
|
||||
#layers=[16, 8, 4, 4],
|
||||
#layers=[20, 4, 4, 4],
|
||||
output_variables=burgers_problem.output_variables,
|
||||
input_variables=burgers_problem.input_variables,
|
||||
func=Softplus,
|
||||
@@ -57,7 +50,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.s:
|
||||
pinn.span_pts(2000, 'latin', ['D'])
|
||||
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 'initia'])
|
||||
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 't0'])
|
||||
pinn.train(5000, 100)
|
||||
pinn.save_state('pina.burger.{}.{}'.format(args.id_run, args.features))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user