Bug fix in GNO and implementation of tests
This commit is contained in:
committed by
Nicola Demo
parent
4c5e1569ff
commit
54a62dee26
@@ -16,6 +16,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
n_layers=2,
|
||||
internal_n_layers=0,
|
||||
internal_layers=None,
|
||||
inner_size=None,
|
||||
internal_func=None,
|
||||
external_func=None,
|
||||
shared_weights=False
|
||||
@@ -50,6 +51,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
layers=internal_layers,
|
||||
inner_size=inner_size,
|
||||
internal_func=internal_func,
|
||||
external_func=external_func)
|
||||
self.n_layers = n_layers
|
||||
@@ -61,6 +63,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
layers=internal_layers,
|
||||
inner_size=inner_size,
|
||||
internal_func=internal_func,
|
||||
external_func=external_func
|
||||
)
|
||||
@@ -150,6 +153,7 @@ class GNO(KernelNeuralOperator):
|
||||
width=lifting_operator.out_features,
|
||||
edge_features=edge_features,
|
||||
internal_n_layers=internal_n_layers,
|
||||
inner_size=inner_size,
|
||||
internal_layers=internal_layers,
|
||||
external_func=external_func,
|
||||
internal_func=internal_func,
|
||||
|
||||
@@ -10,8 +10,9 @@ class GraphIntegralLayer(MessagePassing):
|
||||
self,
|
||||
width,
|
||||
edges_features,
|
||||
n_layers=0,
|
||||
n_layers=2,
|
||||
layers=None,
|
||||
inner_size=None,
|
||||
internal_func=None,
|
||||
external_func=None
|
||||
):
|
||||
@@ -28,10 +29,13 @@ class GraphIntegralLayer(MessagePassing):
|
||||
from pina.model import FeedForward
|
||||
super(GraphIntegralLayer, self).__init__(aggr='mean')
|
||||
self.width = width
|
||||
if layers is None and inner_size is None:
|
||||
inner_size = width
|
||||
self.dense = FeedForward(input_dimensions=edges_features,
|
||||
output_dimensions=width ** 2,
|
||||
n_layers=n_layers,
|
||||
layers=layers,
|
||||
inner_size=inner_size,
|
||||
func=internal_func)
|
||||
self.W = torch.nn.Linear(width, width)
|
||||
self.func = external_func()
|
||||
|
||||
Reference in New Issue
Block a user