Rename classes and modules for GNO

This commit is contained in:
FilippoOlivo
2025-02-05 17:23:46 +01:00
committed by Nicola Demo
parent bd24b0c1c2
commit 6964f4e7d9
5 changed files with 56 additions and 53 deletions

View File

@@ -15,7 +15,7 @@ __all__ = [
"AVNOBlock",
"LowRankBlock",
"RBFBlock",
"GraphIntegralLayer"
"GNOBlock"
]
from .convolution_2d import ContinuousConvBlock
@@ -32,4 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
from .avno_layer import AVNOBlock
from .lowrank_layer import LowRankBlock
from .rbf_layer import RBFBlock
from .graph_integral_kernel import GraphIntegralLayer
from .gno_block import GNOBlock

View File

@@ -2,10 +2,11 @@ import torch
from torch_geometric.nn import MessagePassing
class GraphIntegralLayer(MessagePassing):
class GNOBlock(MessagePassing):
"""
TODO: Add documentation
"""
def __init__(
self,
width,
@@ -27,7 +28,7 @@ class GraphIntegralLayer(MessagePassing):
:type n_layers: int
"""
from pina.model import FeedForward
super(GraphIntegralLayer, self).__init__(aggr='mean')
super(GNOBlock, self).__init__(aggr='mean')
self.width = width
if layers is None and inner_size is None:
inner_size = width