Rename classes and modules for GNO
This commit is contained in:
committed by
Nicola Demo
parent
bd24b0c1c2
commit
6964f4e7d9
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torch.nn import Tanh
|
||||
from .layers import GraphIntegralLayer
|
||||
from .layers import GNOBlock
|
||||
from .base_no import KernelNeuralOperator
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
internal_func = Tanh
|
||||
|
||||
if shared_weights:
|
||||
self.layers = GraphIntegralLayer(
|
||||
self.layers = GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
@@ -58,7 +58,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
self.forward = self.forward_shared
|
||||
else:
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[GraphIntegralLayer(
|
||||
[GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
@@ -101,7 +101,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class GNO(KernelNeuralOperator):
|
||||
class GraphNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
TODO add docstring
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user