From 5d2ca62e652d4ea2ad494dc4d6b93c8048f25435 Mon Sep 17 00:00:00 2001 From: Anna Ivagnes Date: Thu, 3 Oct 2024 17:00:23 +0200 Subject: [PATCH] add multiple outputs possibility in DeepONet --- pina/model/deeponet.py | 6 +++++- tests/test_model/test_deeponet.py | 33 +++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index b39f532..eb5d618 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -172,6 +172,7 @@ class MIONet(torch.nn.Module): raise ValueError(f"Unsupported aggregation: {str(aggregator)}") self._aggregator = aggregator_func + self._aggregator_type = aggregator def _init_reduction(self, reduction): reduction_funcs = DeepONet._symbol_functions(dim=-1) @@ -183,6 +184,7 @@ class MIONet(torch.nn.Module): raise ValueError(f"Unsupported reduction: {reduction}") self._reduction = reduction_func + self._reduction_type = reduction def _get_vars(self, x, indeces): if isinstance(indeces[0], str): @@ -222,7 +224,9 @@ class MIONet(torch.nn.Module): aggregated = self._aggregator(torch.dstack(output_)) # reduce - output_ = self._reduction(aggregated).reshape(-1, 1) + output_ = self._reduction(aggregated) + if self._reduction_type in DeepONet._symbol_functions(dim=-1): + output_ = output_.reshape(-1, 1) # scale and translate output_ *= self._scale diff --git a/tests/test_model/test_deeponet.py b/tests/test_model/test_deeponet.py index 78819e5..9670424 100644 --- a/tests/test_model/test_deeponet.py +++ b/tests/test_model/test_deeponet.py @@ -1,5 +1,6 @@ import pytest import torch +from torch.nn import Linear from pina import LabelTensor from pina.model import DeepONet @@ -8,7 +9,8 @@ from pina.model import FeedForward data = torch.rand((20, 3)) input_vars = ['a', 'b', 'c'] input_ = LabelTensor(data, input_vars) - +symbol_funcs_red = DeepONet._symbol_functions(dim=-1) +output_dims = [1, 5, 10, 20] def test_constructor(): branch_net = FeedForward(input_dimensions=1, output_dimensions=10) @@ -32,7 +34,6 @@ def test_constructor_fails_when_invalid_inner_layer_size(): reduction='+', aggregator='*') - def test_forward_extract_str(): branch_net = FeedForward(input_dimensions=1, output_dimensions=10) trunk_net = FeedForward(input_dimensions=2, output_dimensions=10) @@ -43,6 +44,7 @@ def test_forward_extract_str(): reduction='+', aggregator='*') model(input_) + assert model(input_).shape[-1] == 1 def test_forward_extract_int(): @@ -100,3 +102,30 @@ def test_backward_extract_str_wrong(): l=torch.mean(model(data)) l.backward() assert data._grad.shape == torch.Size([20,3]) + +@pytest.mark.parametrize('red', symbol_funcs_red) +def test_forward_symbol_funcs(red): + branch_net = FeedForward(input_dimensions=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensions=2, output_dimensions=10) + model = DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction=red, + aggregator='*') + model(input_) + assert model(input_).shape[-1] == 1 + +@pytest.mark.parametrize('out_dim', output_dims) +def test_forward_callable_reduction(out_dim): + branch_net = FeedForward(input_dimensions=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensions=2, output_dimensions=10) + reduction_layer = Linear(10, out_dim) + model = DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction=reduction_layer, + aggregator='*') + model(input_) + assert model(input_).shape[-1] == out_dim