minor fixes deeponet, remove LabelTensor import

This commit is contained in:
Dario Coscia
2023-09-07 10:47:26 +02:00
committed by Nicola Demo
parent 3a7fc63760
commit 190ee0561d
2 changed files with 39 additions and 8 deletions

View File

@@ -53,3 +53,15 @@ def test_forward_extract_int():
reduction='+',
aggregator='*')
model(data)
def test_forward_extract_str_wrong():
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
with pytest.raises(RuntimeError):
model(data)