Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
This commit is contained in:
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions

View File

@@ -0,0 +1,113 @@
import argparse
import logging
import torch
from problems.poisson import Poisson
from pina import PINN, LabelTensor, Plotter
from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks
logging.basicConfig(
filename="poisson_deeponet.log", filemode="w", level=logging.INFO
)
class SinFeature(torch.nn.Module):
"""
Feature: sin(x)
"""
def __init__(self, label):
super().__init__()
if not isinstance(label, (tuple, list)):
label = [label]
self._label = label
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
"""
t = torch.sin(x.extract(self._label) * torch.pi)
return LabelTensor(t, [f"sin({self._label})"])
def prepare_deeponet_model(args, problem, extra_feature_combo_func=None):
combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(",")))
check_combos(combos, problem.input_variables)
extra_feature = extra_feature_combo_func if args.extra else None
networks = spawn_combo_networks(
combos=combos,
layers=list(map(int, args.layers.split(","))) if args.layers else [],
output_dimension=args.hidden * len(problem.output_variables),
func=torch.nn.Softplus,
extra_feature=extra_feature,
bias=not args.nobias,
)
return DeepONet(
networks,
problem.output_variables,
aggregator=args.aggregator,
reduction=args.reduction,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA")
parser.add_argument("-s", "--save", action="store_true")
parser.add_argument("-l", "--load", action="store_true")
parser.add_argument("id_run", help="Run ID", type=int)
parser.add_argument("--extra", help="Extra features", action="store_true")
parser.add_argument("--nobias", action="store_true")
parser.add_argument(
"--combos",
help="DeepONet internal network combinations",
type=str,
required=True,
)
parser.add_argument(
"--aggregator", help="Aggregator for DeepONet", type=str, default="*"
)
parser.add_argument(
"--reduction", help="Reduction for DeepONet", type=str, default="+"
)
parser.add_argument(
"--hidden",
help="Number of variables in the hidden DeepONet layer",
type=int,
required=True,
)
parser.add_argument(
"--layers",
help="Structure of the DeepONet partial layers",
type=str,
required=True,
)
cli_args = parser.parse_args()
poisson_problem = Poisson()
model = prepare_deeponet_model(
cli_args,
poisson_problem,
extra_feature_combo_func=lambda combo: [SinFeature(combo)],
)
pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8)
if cli_args.save:
pinn.span_pts(
20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"]
)
pinn.span_pts(20, "grid", locations=["D"])
pinn.train(1.0e-10, 100)
pinn.save_state(f"pina.poisson_{cli_args.id_run}")
if cli_args.load:
pinn.load_state(f"pina.poisson_{cli_args.id_run}")
plotter = Plotter()
plotter.plot(pinn)