add models and layers backward test

This commit is contained in:
cyberguli
2024-02-19 23:09:10 +01:00
committed by Nicola Demo
parent cbb43a5392
commit eb1af0b50e
10 changed files with 308 additions and 1 deletions

View File

@@ -56,6 +56,21 @@ def test_forward_extract_int():
aggregator='*')
model(data)
def test_backward_extract_int():
data = torch.rand((20, 3))
branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensions=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=[0],
input_indeces_trunk_net=[1, 2],
reduction='+',
aggregator='*')
data.requires_grad = True
model(data)
l=torch.mean(model(data))
l.backward()
assert data._grad.shape == torch.Size([20,3])
def test_forward_extract_str_wrong():
branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
@@ -68,3 +83,20 @@ def test_forward_extract_str_wrong():
aggregator='*')
with pytest.raises(RuntimeError):
model(data)
def test_backward_extract_str_wrong():
data = torch.rand((20, 3))
branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensions=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
data.requires_grad = True
with pytest.raises(RuntimeError):
model(data)
l=torch.mean(model(data))
l.backward()
assert data._grad.shape == torch.Size([20,3])