# 54 lines · 1.6 KiB · Python
import torch
|
|
import torch.nn as nn
|
|
from torch_geometric.nn import MessagePassing
|
|
|
|
# from torch.nn.utils import spectral_norm
|
|
|
|
|
|
class GCNConvLayer(MessagePassing):
    """Graph convolution with a custom degree-based edge normalization.

    Node features are first transformed by a two-layer MLP (``self.lin``,
    ``in_channels -> out_channels``). The transformed features are then
    summed over incoming edges, each message scaled by
    ``edge_attr / (1 + sqrt(deg[src] * deg[dst]))``.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        # Node-wise feature transform, applied before message passing.
        self.lin = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
        )

    def _compute_edge_weight(self, edge_index, edge_w, deg):
        """Return one normalized scalar weight per edge.

        Args:
            edge_index: Edge list; ``edge_index[0]`` holds source node ids and
                ``edge_index[1]`` target node ids (presumably shape (2, E)).
            edge_w: Raw edge weights, one value per edge -- assumed shape
                (E,) or (E, 1); TODO confirm against callers.
            deg: Per-node degree values indexed by node id -- assumed shape
                (N,) or (N, 1) and non-negative.

        Returns:
            1-D tensor ``edge_w / (1 + sqrt(deg[src] * deg[dst]))``.
        """
        src, dst = edge_index[0], edge_index[1]
        # "+ 1" keeps the denominator >= 1, so zero-degree endpoints
        # never cause a division by zero.
        denom = 1 + torch.sqrt(deg[src] * deg[dst])
        # Flatten both sides: an (E,) / (E, 1) division would otherwise
        # broadcast into an (E, E) matrix.
        return edge_w.reshape(-1) / denom.reshape(-1)

    def forward(self, x, edge_index, edge_attr, deg):
        """Transform node features and aggregate weighted neighbor messages.

        Args:
            x: Node feature matrix whose last dimension is ``in_channels``.
            edge_index: Source/target node ids per edge.
            edge_attr: Raw per-edge weights (see ``_compute_edge_weight``).
            deg: Per-node degree values used for normalization.

        Returns:
            Node features with last dimension ``out_channels``. Each row is the
            weighted sum of its incoming messages; nodes with no incoming
            edges get zeros.
        """
        edge_w = self._compute_edge_weight(edge_index, edge_attr, deg)
        # Apply the learned transform (GCN-style: transform, then aggregate)
        # so that self.lin actually participates in the forward pass.
        h = self.lin(x)
        return self.propagate(edge_index, x=h, edge_weight=edge_w)

    def message(self, x_j, edge_weight):
        """Scale each edge's source-node features by that edge's weight.

        ``x_j`` is PyG's per-edge lift of ``x`` (one row per edge, taken from
        the source node); ``edge_weight`` is broadcast across the feature axis.
        """
        return edge_weight.view(-1, 1) * x_j
|
|
|
|
|
|
class CorrectionNet(nn.Module):
    """Encoder -> one graph convolution -> decoder, producing one value per node.

    The input's last dimension must be 1, because the first layer is
    ``nn.Linear(1, ...)``. The decoder ends in a ReLU, so every output is
    non-negative.

    Args:
        hidden_dim: Width of the node embedding fed through the graph
            convolution. The encoder and decoder bottleneck layers use
            ``hidden_dim // 2`` units.
    """

    def __init__(self, hidden_dim=8):
        super().__init__()
        half_dim = hidden_dim // 2

        # Lift the single input feature per node up to hidden_dim features.
        self.enc = nn.Sequential(
            nn.Linear(1, half_dim), nn.ReLU(),
            nn.Linear(half_dim, hidden_dim), nn.ReLU(),
        )
        # Exchange information between connected nodes.
        self.model = GCNConvLayer(hidden_dim, hidden_dim)
        # Squeeze each node embedding back down to a single value.
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, half_dim), nn.ReLU(),
            nn.Linear(half_dim, 1), nn.ReLU(),
        )

    def forward(self, x, edge_index, edge_attr, deg):
        """Run encode -> graph convolution -> decode.

        Args:
            x: Node inputs with a trailing dimension of size 1.
            edge_index: Edge list forwarded unchanged to the graph layer.
            edge_attr: Per-edge weights forwarded unchanged to the graph layer.
            deg: Per-node degree values forwarded unchanged to the graph layer.

        Returns:
            Tensor with a trailing dimension of size 1 and non-negative entries.
        """
        embedded = self.enc(x)
        mixed = self.model(embedded, edge_index, edge_attr, deg)
        return self.dec(mixed)
|