fix examples (#21)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import ReLU, Tanh, Softplus
|
||||
from torch.nn import Softplus
|
||||
|
||||
from pina import PINN, LabelTensor, Plotter
|
||||
from pina.model import MultiFeedForward
|
||||
from problems.parametric_elliptic_optimal_control_alpha_variable import (
|
||||
ParametricEllipticOptimalControl)
|
||||
|
||||
from pina import PINN, LabelTensor
|
||||
from parametric_elliptic_optimal_control_alpha_variable2 import ParametricEllipticOptimalControl
|
||||
from pina.model import MultiFeedForward, FeedForward
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
"""
|
||||
@@ -31,7 +33,6 @@ class CustomMultiDFF(MultiFeedForward):
|
||||
return out.append(p)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run PINA")
|
||||
@@ -71,8 +72,12 @@ if __name__ == "__main__":
|
||||
{'variables': ['mu', 'alpha'], 'mode': 'grid', 'n': 5},
|
||||
locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
|
||||
pinn.train(10000, 20)
|
||||
pinn.train(1000, 20)
|
||||
pinn.save_state('pina.ocp')
|
||||
|
||||
else:
|
||||
pinn.load_state('pina.ocp')
|
||||
plotter = Plotter()
|
||||
plotter.plot(pinn, components='y', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
plotter.plot(pinn, components='u_param', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
plotter.plot(pinn, components='p', fixed_variables={'alpha': 0.01, 'mu': 1.0})
|
||||
|
||||
Reference in New Issue
Block a user