diff --git a/pina/model/gno.py b/pina/model/gno.py
index 991ca39..3e9af8a 100644
--- a/pina/model/gno.py
+++ b/pina/model/gno.py
@@ -16,6 +16,7 @@ class GraphNeuralKernel(torch.nn.Module):
             n_layers=2,
             internal_n_layers=0,
             internal_layers=None,
+            inner_size=None,
             internal_func=None,
             external_func=None,
             shared_weights=False
@@ -50,6 +51,7 @@ class GraphNeuralKernel(torch.nn.Module):
                 edges_features=edge_features,
                 n_layers=internal_n_layers,
                 layers=internal_layers,
+                inner_size=inner_size,
                 internal_func=internal_func,
                 external_func=external_func)
         self.n_layers = n_layers
@@ -61,6 +63,7 @@ class GraphNeuralKernel(torch.nn.Module):
                 edges_features=edge_features,
                 n_layers=internal_n_layers,
                 layers=internal_layers,
+                inner_size=inner_size,
                 internal_func=internal_func,
                 external_func=external_func
             )
@@ -150,6 +153,7 @@ class GNO(KernelNeuralOperator):
             width=lifting_operator.out_features,
             edge_features=edge_features,
             internal_n_layers=internal_n_layers,
+            inner_size=inner_size,
             internal_layers=internal_layers,
             external_func=external_func,
             internal_func=internal_func,
diff --git a/pina/model/layers/graph_integral_kernel.py b/pina/model/layers/graph_integral_kernel.py
index 713b0d7..70d172c 100644
--- a/pina/model/layers/graph_integral_kernel.py
+++ b/pina/model/layers/graph_integral_kernel.py
@@ -10,8 +10,9 @@ class GraphIntegralLayer(MessagePassing):
         self,
         width,
         edges_features,
-        n_layers=0,
+        n_layers=2,
         layers=None,
+        inner_size=None,
         internal_func=None,
         external_func=None
     ):
@@ -28,10 +29,13 @@ class GraphIntegralLayer(MessagePassing):
         from pina.model import FeedForward
         super(GraphIntegralLayer, self).__init__(aggr='mean')
         self.width = width
+        if layers is None and inner_size is None:
+            inner_size = width
         self.dense = FeedForward(input_dimensions=edges_features,
                                  output_dimensions=width ** 2,
                                  n_layers=n_layers,
                                  layers=layers,
+                                 inner_size=inner_size,
                                  func=internal_func)
         self.W = torch.nn.Linear(width, width)
         self.func = external_func()
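Reviewer note: a minimal sketch of what the `inner_size` fallback above controls. When neither `layers` nor `inner_size` is given, the edge MLP's hidden size defaults to `width`, and the MLP maps each edge-attribute vector to `width ** 2` values; integral kernel layers of this kind typically reshape that output into a per-edge width x width matrix applied to the neighbour state. The `Sequential` stand-in and the reshape-and-apply step below are assumptions for illustration, not code from this patch:

    import torch

    width, n_edges, edge_features = 16, 32, 3

    # Illustrative stand-in for FeedForward(input_dimensions=edge_features,
    # output_dimensions=width ** 2, n_layers=2, inner_size=width).
    dense = torch.nn.Sequential(
        torch.nn.Linear(edge_features, width),  # hidden size = inner_size = width
        torch.nn.Tanh(),
        torch.nn.Linear(width, width ** 2),
    )

    edge_attr = torch.rand(n_edges, edge_features)  # one attribute vector per edge
    x_j = torch.rand(n_edges, width)                # lifted neighbour node states

    # Assumed message step: reshape each MLP output into a width x width
    # kernel matrix and apply it to the corresponding neighbour state.
    K = dense(edge_attr).view(-1, width, width)
    messages = torch.einsum('eij,ej->ei', K, x_j)
    assert messages.shape == (n_edges, width)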
diff --git a/tests/test_model/test_gno.py b/tests/test_model/test_gno.py
new file mode 100644
index 0000000..e07d845
--- /dev/null
+++ b/tests/test_model/test_gno.py
@@ -0,0 +1,127 @@
+import pytest
+import torch
+from pina.graph import KNNGraph
+from pina.model import GNO
+from torch_geometric.data import Batch
+
+x = [torch.rand(100, 6) for _ in range(10)]
+pos = [torch.rand(100, 3) for _ in range(10)]
+graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6)
+input_ = Batch.from_data_list(graph.data)
+
+
+@pytest.mark.parametrize(
+    "shared_weights",
+    [
+        True,
+        False
+    ]
+)
+def test_constructor(shared_weights):
+    lifting_operator = torch.nn.Linear(6, 16)
+    projection_operator = torch.nn.Linear(16, 3)
+    GNO(lifting_operator=lifting_operator,
+        projection_operator=projection_operator,
+        edge_features=3,
+        internal_layers=[16, 16],
+        shared_weights=shared_weights)
+
+    GNO(lifting_operator=lifting_operator,
+        projection_operator=projection_operator,
+        edge_features=3,
+        inner_size=16,
+        internal_n_layers=10,
+        shared_weights=shared_weights)
+
+    int_func = torch.nn.Softplus
+    ext_func = torch.nn.ReLU
+
+    GNO(lifting_operator=lifting_operator,
+        projection_operator=projection_operator,
+        edge_features=3,
+        internal_n_layers=10,
+        shared_weights=shared_weights,
+        internal_func=int_func,
+        external_func=ext_func)
+
+
+@pytest.mark.parametrize(
+    "shared_weights",
+    [
+        True,
+        False
+    ]
+)
+def test_forward_1(shared_weights):
+    lifting_operator = torch.nn.Linear(6, 16)
+    projection_operator = torch.nn.Linear(16, 3)
+    model = GNO(lifting_operator=lifting_operator,
+                projection_operator=projection_operator,
+                edge_features=3,
+                internal_layers=[16, 16],
+                shared_weights=shared_weights)
+    output_ = model(input_)
+    assert output_.shape == torch.Size([1000, 3])
+
+
+@pytest.mark.parametrize(
+    "shared_weights",
+    [
+        True,
+        False
+    ]
+)
+def test_forward_2(shared_weights):
+    lifting_operator = torch.nn.Linear(6, 16)
+    projection_operator = torch.nn.Linear(16, 3)
+    model = GNO(lifting_operator=lifting_operator,
+                projection_operator=projection_operator,
+                edge_features=3,
+                inner_size=32,
+                internal_n_layers=2,
+                shared_weights=shared_weights)
+    output_ = model(input_)
+    assert output_.shape == torch.Size([1000, 3])
+
+
+@pytest.mark.parametrize(
+    "shared_weights",
+    [
+        True,
+        False
+    ]
+)
+def test_backward(shared_weights):
+    lifting_operator = torch.nn.Linear(6, 16)
+    projection_operator = torch.nn.Linear(16, 3)
+    model = GNO(lifting_operator=lifting_operator,
+                projection_operator=projection_operator,
+                edge_features=3,
+                internal_layers=[16, 16],
+                shared_weights=shared_weights)
+    input_.x.requires_grad = True
+    output_ = model(input_)
+    l = torch.mean(output_)
+    l.backward()
+    assert input_.x.grad.shape == torch.Size([1000, 6])
+
+
+@pytest.mark.parametrize(
+    "shared_weights",
+    [
+        True,
+        False
+    ]
+)
+def test_backward_2(shared_weights):
+    lifting_operator = torch.nn.Linear(6, 16)
+    projection_operator = torch.nn.Linear(16, 3)
+    model = GNO(lifting_operator=lifting_operator,
+                projection_operator=projection_operator,
+                edge_features=3,
+                inner_size=32,
+                internal_n_layers=2,
+                shared_weights=shared_weights)
+    input_.x.requires_grad = True
+    output_ = model(input_)
+    l = torch.mean(output_)
+    l.backward()
+    assert input_.x.grad.shape == torch.Size([1000, 6])
\ No newline at end of file
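One case the tests above do not exercise: with the fallback added in graph_integral_kernel.py, omitting both `internal_layers` and `inner_size` should now build the edge MLP with a hidden size equal to the lifted width (16 here). A minimal sketch, assuming the same `KNNGraph`/`GNO` API used in the tests:

    import torch
    from torch_geometric.data import Batch
    from pina.graph import KNNGraph
    from pina.model import GNO

    x = [torch.rand(100, 6) for _ in range(10)]
    pos = [torch.rand(100, 3) for _ in range(10)]
    batch = Batch.from_data_list(
        KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6).data)

    # Neither internal_layers nor inner_size is passed: the new fallback sets
    # the edge MLP's hidden size to width (= lifting_operator.out_features).
    model = GNO(lifting_operator=torch.nn.Linear(6, 16),
                projection_operator=torch.nn.Linear(16, 3),
                edge_features=3,
                internal_n_layers=2)
    assert model(batch).shape == torch.Size([1000, 3])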