Rename classes and modules for GNO

This commit is contained in:
FilippoOlivo
2025-02-05 17:23:46 +01:00
committed by Nicola Demo
parent bd24b0c1c2
commit 6964f4e7d9
5 changed files with 56 additions and 53 deletions

View File

@@ -1,6 +1,6 @@
import torch
from torch.nn import Tanh
from .layers import GraphIntegralLayer
from .layers import GNOBlock
from .base_no import KernelNeuralOperator
@@ -46,7 +46,7 @@ class GraphNeuralKernel(torch.nn.Module):
internal_func = Tanh
if shared_weights:
self.layers = GraphIntegralLayer(
self.layers = GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
@@ -58,7 +58,7 @@ class GraphNeuralKernel(torch.nn.Module):
self.forward = self.forward_shared
else:
self.layers = torch.nn.ModuleList(
[GraphIntegralLayer(
[GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
@@ -101,7 +101,7 @@ class GraphNeuralKernel(torch.nn.Module):
return x
class GNO(KernelNeuralOperator):
class GraphNeuralOperator(KernelNeuralOperator):
"""
TODO add docstring
"""