Refactoring code
This commit is contained in:
@@ -2,15 +2,16 @@ import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import argparse
|
||||
from pina.pinn import PINN
|
||||
from pina import PINN
|
||||
from pina.ppinn import ParametricPINN as pPINN
|
||||
from pina.label_tensor import LabelTensor
|
||||
from torch.nn import ReLU, Tanh, Softplus
|
||||
from problems.poisson2D import Poisson2DProblem as Poisson2D
|
||||
from problems.poisson2D import Poisson2D
|
||||
from pina.deep_feed_forward import DeepFeedForward
|
||||
|
||||
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
||||
|
||||
from pina import Plotter
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
"""
|
||||
@@ -54,17 +55,18 @@ if __name__ == "__main__":
|
||||
|
||||
if args.s:
|
||||
|
||||
pinn.span_pts(10, 'grid', ['D'])
|
||||
pinn.span_pts(10, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
pinn.span_pts(20, 'grid', ['D'])
|
||||
pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
#pinn.plot_pts()
|
||||
pinn.train(10000, 100)
|
||||
pinn.train(1000, 100)
|
||||
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
|
||||
for i, losses in enumerate(pinn.history):
|
||||
file_.write('{} {}\n'.format(i, sum(losses).item()))
|
||||
file_.write('{} {}\n'.format(i, sum(losses)))
|
||||
pinn.save_state('pina.poisson')
|
||||
|
||||
else:
|
||||
pinn.load_state('pina.poisson')
|
||||
pinn.plot(40)
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user