Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
This commit is contained in:
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions

View File

@@ -1,9 +1,9 @@
import torch
import pytest
import torch
from pina import LabelTensor
from pina.model import DeepONet, FeedForward as FFN
from pina.model import DeepONet
from pina.model import FeedForward as FFN
data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c']
@@ -14,19 +14,17 @@ input_ = LabelTensor(data, input_vars)
def test_constructor():
branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=20)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
def test_constructor_fails_when_invalid_inner_layer_size():
branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=19)
with pytest.raises(ValueError):
DeepONet(trunk_net=trunk, branch_net=branch, output_variables=output_vars)
DeepONet(nets=[trunk, branch], output_variables=output_vars)
def test_forward():
branch = FFN(input_variables=['a', 'c'], output_variables=10)
trunk = FFN(input_variables=['b'], output_variables=10)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
output_ = onet(input_)
assert output_.labels == output_vars