fix old codes
This commit is contained in:
@@ -2,9 +2,9 @@ import argparse
|
||||
import torch
|
||||
from torch.nn import Softplus
|
||||
|
||||
from pina import PINN, Plotter
|
||||
from pina import PINN, Plotter, LabelTensor
|
||||
from pina.model import FeedForward
|
||||
from problems.burgers import Burgers1D
|
||||
from burger2 import Burgers1D
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -16,7 +16,7 @@ class myFeature(torch.nn.Module):
|
||||
self.idx = idx
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sin(torch.pi * x[:, self.idx])
|
||||
return LabelTensor(torch.sin(torch.pi * x.extract(['x'])), ['sin(x)'])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -45,12 +45,14 @@ if __name__ == "__main__":
|
||||
model,
|
||||
lr=0.006,
|
||||
error_norm='mse',
|
||||
regularizer=0,
|
||||
lr_accelerate=None)
|
||||
regularizer=0)
|
||||
|
||||
if args.s:
|
||||
pinn.span_pts(2000, 'latin', ['D'])
|
||||
pinn.span_pts(150, 'random', ['gamma1', 'gamma2', 't0'])
|
||||
pinn.span_pts(
|
||||
{'n': 200, 'mode': 'random', 'variables': 't'},
|
||||
{'n': 20, 'mode': 'random', 'variables': 'x'},
|
||||
locations=['D'])
|
||||
pinn.span_pts(150, 'random', location=['gamma1', 'gamma2', 't0'])
|
||||
pinn.train(5000, 100)
|
||||
pinn.save_state('pina.burger.{}.{}'.format(args.id_run, args.features))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user