Examples update for v0.1 (#206)

* modify examples/problems
* modify tutorials

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-235.eduroam.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-14 18:24:07 +01:00
committed by Nicola Demo
parent 0d38de5afe
commit ee39b39805
19 changed files with 605 additions and 613 deletions

View File

@@ -1,67 +1,53 @@
""" Run PINA on ODE equation. """
import argparse
import torch
from torch.nn import Softplus
from pina.model import FeedForward
from pina import Condition, CartesianDomain, Plotter, PINN
from pina.solvers import PINN
from pina.plotter import Plotter
from pina.trainer import Trainer
from problems.first_order_ode import FirstOrderODE
class FirstOrderODE(SpatialProblem):
x_rng = [0, 5]
output_variables = ['y']
spatial_domain = CartesianDomain({'x': x_rng})
def ode(input_, output_):
y = output_
x = input_
return grad(y, x) + y - x
def fixed(input_, output_):
exp_value = 1.
return output_ - exp_value
def solution(self, input_):
x = input_
return x - 1.0 + 2*torch.exp(-x)
conditions = {
'bc': Condition(CartesianDomain({'x': x_rng[0]}), fixed),
'dd': Condition(CartesianDomain({'x': x_rng}), ode),
}
truth_solution = solution
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-s", "-save", action="store_true")
group.add_argument("-l", "-load", action="store_true")
parser.add_argument("id_run", help="number of run", type=int)
parser.add_argument("--load", help="directory to save or load file", type=str)
parser.add_argument("--epochs", help="extra features", type=int, default=3000)
args = parser.parse_args()
# define Problem + Model + PINN
# create problem and discretise domain
problem = FirstOrderODE()
problem.discretise_domain(n=500, mode='grid', variables = 'x', locations=['D'])
problem.discretise_domain(n=1, mode='grid', variables = 'x', locations=['BC'])
# create model
model = FeedForward(
layers=[4]*2,
output_variables=problem.output_variables,
input_variables=problem.input_variables,
func=Softplus,
layers=[10, 10],
output_dimensions=len(problem.output_variables),
input_dimensions=len(problem.input_variables),
func=Softplus
)
pinn = PINN(problem, model, lr=0.03, error_norm='mse', regularizer=0)
if args.s:
# create solver
pinn = PINN(
problem=problem,
model=model,
extra_features=None,
optimizer_kwargs={'lr' : 0.001}
)
pinn.span_pts(
{'variables': ['x'], 'mode': 'grid', 'n': 1}, locations=['bc'])
pinn.span_pts(
{'variables': ['x'], 'mode': 'grid', 'n': 30}, locations=['dd'])
Plotter().plot_samples(pinn, ['x'])
pinn.train(1200, 50)
pinn.save_state('pina.ode')
# create trainer
directory = 'pina.ode'
trainer = Trainer(solver=pinn, accelerator='cpu', max_epochs=args.epochs, default_root_dir=directory)
else:
pinn.load_state('pina.ode')
if args.load:
pinn = PINN.load_from_checkpoint(checkpoint_path=args.load, problem=problem, model=model)
plotter = Plotter()
plotter.plot(pinn, components=['y'])
plotter.plot(pinn)
else:
trainer.train()