Refactoring code

This commit is contained in:
Your Name
2022-01-27 14:55:42 +01:00
parent fb16fc7f3a
commit fa8ffd5042
32 changed files with 417 additions and 442 deletions

View File

@@ -2,15 +2,16 @@ import sys
import numpy as np
import torch
import argparse
from pina.pinn import PINN
from pina import PINN
from pina.ppinn import ParametricPINN as pPINN
from pina.label_tensor import LabelTensor
from torch.nn import ReLU, Tanh, Softplus
from problems.poisson2D import Poisson2DProblem as Poisson2D
from problems.poisson2D import Poisson2D
from pina.deep_feed_forward import DeepFeedForward
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
from pina import Plotter
class myFeature(torch.nn.Module):
"""
@@ -54,17 +55,18 @@ if __name__ == "__main__":
if args.s:
pinn.span_pts(10, 'grid', ['D'])
pinn.span_pts(10, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.span_pts(20, 'grid', ['D'])
pinn.span_pts(20, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
#pinn.plot_pts()
pinn.train(10000, 100)
pinn.train(1000, 100)
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
for i, losses in enumerate(pinn.history):
file_.write('{} {}\n'.format(i, sum(losses).item()))
file_.write('{} {}\n'.format(i, sum(losses)))
pinn.save_state('pina.poisson')
else:
pinn.load_state('pina.poisson')
pinn.plot(40)
plotter = Plotter()
plotter.plot(pinn)