Formatting
* Adding black as dev dependency
* Formatting pina code
* Formatting tests
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
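The first bullet adds black as a development dependency; the actual dependency declaration is not part of this excerpt, but the effect of the formatter can be reproduced with black's public Python API. A hedged sketch, assuming the black package is installed:

    # Hedged illustration: reproduce black's formatting on a signature similar to
    # the ones touched in this diff, using the public format_str / Mode API with
    # default settings.
    import black

    src = (
        "def __init__(self, width, edge_features, n_layers=2,\n"
        "             internal_func=None, external_func=None, shared_weights=False):\n"
        "    pass\n"
    )

    # format_str returns the source rewritten in black's style: the over-long
    # signature is exploded one parameter per line with a magic trailing comma,
    # matching the pattern visible in the hunks below.
    print(black.format_str(src, mode=black.Mode()))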
@@ -10,16 +10,16 @@ class GraphNeuralKernel(torch.nn.Module):
     """

     def __init__(
-            self,
-            width,
-            edge_features,
-            n_layers=2,
-            internal_n_layers=0,
-            internal_layers=None,
-            inner_size=None,
-            internal_func=None,
-            external_func=None,
-            shared_weights=False
+        self,
+        width,
+        edge_features,
+        n_layers=2,
+        internal_n_layers=0,
+        internal_layers=None,
+        inner_size=None,
+        internal_func=None,
+        external_func=None,
+        shared_weights=False,
     ):
         """
         The Graph Neural Kernel constructor.
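For context, the reformatted signature above can be exercised as in the sketch below; the import path is an assumption about pina's layout rather than something this diff shows, and the unspecified keywords keep the defaults listed in the signature.

    # Hedged usage sketch for the __init__ signature above. The import path is a
    # guess at pina's module layout; adjust it to the actual package structure.
    from pina.model import GraphNeuralKernel  # hypothetical import location

    kernel = GraphNeuralKernel(
        width=32,          # hidden feature size per node
        edge_features=3,   # number of features carried by each edge
        n_layers=2,        # default shown above
        shared_weights=False,
    )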
@@ -53,21 +53,24 @@ class GraphNeuralKernel(torch.nn.Module):
                 layers=internal_layers,
                 inner_size=inner_size,
                 internal_func=internal_func,
-                external_func=external_func)
+                external_func=external_func,
+            )
             self.n_layers = n_layers
             self.forward = self.forward_shared
         else:
             self.layers = torch.nn.ModuleList(
-                [GNOBlock(
-                    width=width,
-                    edges_features=edge_features,
-                    n_layers=internal_n_layers,
-                    layers=internal_layers,
-                    inner_size=inner_size,
-                    internal_func=internal_func,
-                    external_func=external_func
-                )
-                for _ in range(n_layers)]
+                [
+                    GNOBlock(
+                        width=width,
+                        edges_features=edge_features,
+                        n_layers=internal_n_layers,
+                        layers=internal_layers,
+                        inner_size=inner_size,
+                        internal_func=internal_func,
+                        external_func=external_func,
+                    )
+                    for _ in range(n_layers)
+                ]
             )

     def forward(self, x, edge_index, edge_attr):
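The branch being reformatted here switches between a single GNOBlock reused for every layer (shared weights, with forward swapped to forward_shared) and a torch.nn.ModuleList of independent blocks. A minimal, self-contained sketch of that pattern, with a stand-in block in place of pina's GNOBlock, looks like this:

    # Illustrative sketch of the weight-sharing pattern visible in the hunk above:
    # one block applied n_layers times versus a ModuleList of independent blocks.
    # "Block" is a stand-in, not pina's GNOBlock.
    import torch

    class Block(torch.nn.Module):
        def __init__(self, width):
            super().__init__()
            self.linear = torch.nn.Linear(width, width)

        def forward(self, x):
            return torch.relu(self.linear(x))

    class Stack(torch.nn.Module):
        def __init__(self, width, n_layers=2, shared_weights=False):
            super().__init__()
            self.n_layers = n_layers
            if shared_weights:
                self.layers = Block(width)          # one block, reused every pass
                self.forward = self.forward_shared  # swap forward, as in the diff
            else:
                self.layers = torch.nn.ModuleList(
                    [Block(width) for _ in range(n_layers)]
                )

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

        def forward_shared(self, x):
            for _ in range(self.n_layers):
                x = self.layers(x)
            return x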
@@ -107,17 +110,17 @@ class GraphNeuralOperator(KernelNeuralOperator):
     """

     def __init__(
-            self,
-            lifting_operator,
-            projection_operator,
-            edge_features,
-            n_layers=10,
-            internal_n_layers=0,
-            inner_size=None,
-            internal_layers=None,
-            internal_func=None,
-            external_func=None,
-            shared_weights=True
+        self,
+        lifting_operator,
+        projection_operator,
+        edge_features,
+        n_layers=10,
+        internal_n_layers=0,
+        inner_size=None,
+        internal_layers=None,
+        internal_func=None,
+        external_func=None,
+        shared_weights=True,
     ):
         """
         The Graph Neural Operator constructor.
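As with the kernel, the operator signature above can be exercised roughly as follows; the import path is again an assumption, and plain Linear layers stand in for the lifting and projection networks:

    # Hedged construction sketch for the __init__ signature above. The import
    # path is a guess; torch.nn.Linear layers stand in for lifting/projection.
    import torch
    from pina.model import GraphNeuralOperator  # hypothetical import location

    width = 32
    operator = GraphNeuralOperator(
        lifting_operator=torch.nn.Linear(3, width),     # input features -> hidden width
        projection_operator=torch.nn.Linear(width, 1),  # hidden width -> output field
        edge_features=3,
        n_layers=10,          # default shown above
        shared_weights=True,  # default shown above
    )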
@@ -158,9 +161,9 @@ class GraphNeuralOperator(KernelNeuralOperator):
                 external_func=external_func,
                 internal_func=internal_func,
                 n_layers=n_layers,
-                shared_weights=shared_weights
+                shared_weights=shared_weights,
             ),
-            projection_operator=projection_operator
+            projection_operator=projection_operator,
         )

     def forward(self, x):
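The call reformatted in this last hunk passes the graph kernel and the projection network up to KernelNeuralOperator, i.e. the usual lift, iterate the kernel, project composition. A generic illustration of that data flow (not pina's actual forward pass) is:

    # Generic encode-process-decode composition, as suggested by the constructor
    # call above: lift node features, apply the (possibly shared) kernel layers,
    # then project to the output dimension. Illustration only; Linear layers
    # stand in for the lifting, kernel, and projection modules.
    import torch

    lifting = torch.nn.Linear(3, 32)
    kernel_layer = torch.nn.Linear(32, 32)   # stand-in for the graph kernel block
    projection = torch.nn.Linear(32, 1)

    x = torch.rand(100, 3)                   # 100 nodes with 3 input features
    h = lifting(x)
    for _ in range(10):                      # n_layers applications of the kernel
        h = torch.tanh(kernel_layer(h))
    out = projection(h)
    print(out.shape)                         # torch.Size([100, 1])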