Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -10,16 +10,16 @@ class GraphNeuralKernel(torch.nn.Module):
"""
def __init__(
self,
width,
edge_features,
n_layers=2,
internal_n_layers=0,
internal_layers=None,
inner_size=None,
internal_func=None,
external_func=None,
shared_weights=False
self,
width,
edge_features,
n_layers=2,
internal_n_layers=0,
internal_layers=None,
inner_size=None,
internal_func=None,
external_func=None,
shared_weights=False,
):
"""
The Graph Neural Kernel constructor.
@@ -53,21 +53,24 @@ class GraphNeuralKernel(torch.nn.Module):
layers=internal_layers,
inner_size=inner_size,
internal_func=internal_func,
external_func=external_func)
external_func=external_func,
)
self.n_layers = n_layers
self.forward = self.forward_shared
else:
self.layers = torch.nn.ModuleList(
[GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
inner_size=inner_size,
internal_func=internal_func,
external_func=external_func
)
for _ in range(n_layers)]
[
GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
inner_size=inner_size,
internal_func=internal_func,
external_func=external_func,
)
for _ in range(n_layers)
]
)
def forward(self, x, edge_index, edge_attr):
@@ -107,17 +110,17 @@ class GraphNeuralOperator(KernelNeuralOperator):
"""
def __init__(
self,
lifting_operator,
projection_operator,
edge_features,
n_layers=10,
internal_n_layers=0,
inner_size=None,
internal_layers=None,
internal_func=None,
external_func=None,
shared_weights=True
self,
lifting_operator,
projection_operator,
edge_features,
n_layers=10,
internal_n_layers=0,
inner_size=None,
internal_layers=None,
internal_func=None,
external_func=None,
shared_weights=True,
):
"""
The Graph Neural Operator constructor.
@@ -158,9 +161,9 @@ class GraphNeuralOperator(KernelNeuralOperator):
external_func=external_func,
internal_func=internal_func,
n_layers=n_layers,
shared_weights=shared_weights
shared_weights=shared_weights,
),
projection_operator=projection_operator
projection_operator=projection_operator,
)
def forward(self, x):