Rename classes and modules for GNO
This commit is contained in:
committed by
Nicola Demo
parent
bd24b0c1c2
commit
6964f4e7d9
@@ -10,7 +10,7 @@ __all__ = [
|
||||
"AveragingNeuralOperator",
|
||||
"LowRankNeuralOperator",
|
||||
"Spline",
|
||||
"GNO"
|
||||
"GraphNeuralOperator"
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward, ResidualFeedForward
|
||||
@@ -21,4 +21,4 @@ from .base_no import KernelNeuralOperator
|
||||
from .avno import AveragingNeuralOperator
|
||||
from .lno import LowRankNeuralOperator
|
||||
from .spline import Spline
|
||||
from .gno import GNO
|
||||
from .gno import GraphNeuralOperator
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torch.nn import Tanh
|
||||
from .layers import GraphIntegralLayer
|
||||
from .layers import GNOBlock
|
||||
from .base_no import KernelNeuralOperator
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
internal_func = Tanh
|
||||
|
||||
if shared_weights:
|
||||
self.layers = GraphIntegralLayer(
|
||||
self.layers = GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
@@ -58,7 +58,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
self.forward = self.forward_shared
|
||||
else:
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[GraphIntegralLayer(
|
||||
[GNOBlock(
|
||||
width=width,
|
||||
edges_features=edge_features,
|
||||
n_layers=internal_n_layers,
|
||||
@@ -101,7 +101,7 @@ class GraphNeuralKernel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class GNO(KernelNeuralOperator):
|
||||
class GraphNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
TODO add docstring
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,7 @@ __all__ = [
|
||||
"AVNOBlock",
|
||||
"LowRankBlock",
|
||||
"RBFBlock",
|
||||
"GraphIntegralLayer"
|
||||
"GNOBlock"
|
||||
]
|
||||
|
||||
from .convolution_2d import ContinuousConvBlock
|
||||
@@ -32,4 +32,4 @@ from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
|
||||
from .avno_layer import AVNOBlock
|
||||
from .lowrank_layer import LowRankBlock
|
||||
from .rbf_layer import RBFBlock
|
||||
from .graph_integral_kernel import GraphIntegralLayer
|
||||
from .gno_block import GNOBlock
|
||||
|
||||
@@ -2,10 +2,11 @@ import torch
|
||||
from torch_geometric.nn import MessagePassing
|
||||
|
||||
|
||||
class GraphIntegralLayer(MessagePassing):
|
||||
class GNOBlock(MessagePassing):
|
||||
"""
|
||||
TODO: Add documentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
@@ -27,7 +28,7 @@ class GraphIntegralLayer(MessagePassing):
|
||||
:type n_layers: int
|
||||
"""
|
||||
from pina.model import FeedForward
|
||||
super(GraphIntegralLayer, self).__init__(aggr='mean')
|
||||
super(GNOBlock, self).__init__(aggr='mean')
|
||||
self.width = width
|
||||
if layers is None and inner_size is None:
|
||||
inner_size = width
|
||||
Reference in New Issue
Block a user