fix examples (#21)

This commit is contained in:
Nicola Demo
2022-07-21 13:41:59 +02:00
committed by GitHub
parent 62f203fcc3
commit e8c2f87460
14 changed files with 63 additions and 67 deletions

View File

@@ -1,11 +1,13 @@
import argparse
import numpy as np
import torch
from torch.nn import ReLU, Tanh, Softplus
from torch.nn import Softplus
from pina import PINN, LabelTensor, Plotter
from pina.model import MultiFeedForward
from problems.parametric_elliptic_optimal_control_alpha_variable import (
ParametricEllipticOptimalControl)
from pina import PINN, LabelTensor
from parametric_elliptic_optimal_control_alpha_variable2 import ParametricEllipticOptimalControl
from pina.model import MultiFeedForward, FeedForward
class myFeature(torch.nn.Module):
"""
@@ -31,7 +33,6 @@ class CustomMultiDFF(MultiFeedForward):
return out.append(p)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA")
@@ -71,8 +72,12 @@ if __name__ == "__main__":
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.train(10000, 20)
pinn.train(1000, 20)
pinn.save_state('pina.ocp')
else:
pinn.load_state('pina.ocp')
plotter = Plotter()
plotter.plot(pinn, components='y', fixed_variables={'alpha': 0.01, 'mu': 1.0})
plotter.plot(pinn, components='u_param', fixed_variables={'alpha': 0.01, 'mu': 1.0})
plotter.plot(pinn, components='p', fixed_variables={'alpha': 0.01, 'mu': 1.0})