minor fixes deeponet, remove LabelTensor import
This commit is contained in:
committed by
Nicola Demo
parent
3a7fc63760
commit
190ee0561d
@@ -53,3 +53,15 @@ def test_forward_extract_int():
|
||||
reduction='+',
|
||||
aggregator='*')
|
||||
model(data)
|
||||
|
||||
def test_forward_extract_str_wrong():
|
||||
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||
model = DeepONet(branch_net=branch_net,
|
||||
trunk_net=trunk_net,
|
||||
input_indeces_branch_net=['a'],
|
||||
input_indeces_trunk_net=['b', 'c'],
|
||||
reduction='+',
|
||||
aggregator='*')
|
||||
with pytest.raises(RuntimeError):
|
||||
model(data)
|
||||
|
||||
Reference in New Issue
Block a user