Rename classes and modules for GNO

This commit is contained in:
FilippoOlivo
2025-02-05 17:23:46 +01:00
committed by Nicola Demo
parent bd24b0c1c2
commit 6964f4e7d9
5 changed files with 56 additions and 53 deletions

View File

@@ -10,7 +10,7 @@ __all__ = [
"AveragingNeuralOperator",
"LowRankNeuralOperator",
"Spline",
"GNO"
"GraphNeuralOperator"
]
from .feed_forward import FeedForward, ResidualFeedForward
@@ -21,4 +21,4 @@ from .base_no import KernelNeuralOperator
from .avno import AveragingNeuralOperator
from .lno import LowRankNeuralOperator
from .spline import Spline
from .gno import GNO
from .gno import GraphNeuralOperator

View File

@@ -1,6 +1,6 @@
import torch
from torch.nn import Tanh
from .layers import GraphIntegralLayer
from .layers import GNOBlock
from .base_no import KernelNeuralOperator
@@ -46,7 +46,7 @@ class GraphNeuralKernel(torch.nn.Module):
internal_func = Tanh
if shared_weights:
self.layers = GraphIntegralLayer(
self.layers = GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
@@ -58,7 +58,7 @@ class GraphNeuralKernel(torch.nn.Module):
self.forward = self.forward_shared
else:
self.layers = torch.nn.ModuleList(
[GraphIntegralLayer(
[GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
@@ -101,7 +101,7 @@ class GraphNeuralKernel(torch.nn.Module):
return x
class GNO(KernelNeuralOperator):
class GraphNeuralOperator(KernelNeuralOperator):
"""
TODO add docstring
"""

View File

@@ -15,7 +15,7 @@ __all__ = [
"AVNOBlock",
"LowRankBlock",
"RBFBlock",
"GraphIntegralLayer"
"GNOBlock"
]
from .convolution_2d import ContinuousConvBlock
@@ -32,4 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
from .avno_layer import AVNOBlock
from .lowrank_layer import LowRankBlock
from .rbf_layer import RBFBlock
from .graph_integral_kernel import GraphIntegralLayer
from .gno_block import GNOBlock

View File

@@ -2,10 +2,11 @@ import torch
from torch_geometric.nn import MessagePassing
class GraphIntegralLayer(MessagePassing):
class GNOBlock(MessagePassing):
"""
TODO: Add documentation
"""
def __init__(
self,
width,
@@ -27,7 +28,7 @@ class GraphIntegralLayer(MessagePassing):
:type n_layers: int
"""
from pina.model import FeedForward
super(GraphIntegralLayer, self).__init__(aggr='mean')
super(GNOBlock, self).__init__(aggr='mean')
self.width = width
if layers is None and inner_size is None:
inner_size = width