import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

# from torch.nn.utils import spectral_norm


class GCNConvLayer(MessagePassing):
    """Edge-weighted message-passing layer with a small MLP feature transform.

    Node features are first mapped from ``in_channels`` to ``out_channels`` by
    ``self.lin`` (Linear-ReLU-Linear-ReLU). Then, for every edge
    ``(edge_index[0][k], edge_index[1][k])``, the transformed source feature
    is scaled by a per-edge weight and the messages are summed (``aggr="add"``).

    The edge weight is the raw edge attribute damped by the endpoint degrees::

        w_k = edge_attr_k / (1 + sqrt(deg[src_k] * deg[dst_k]))

    The ``1 +`` keeps the weight finite when an endpoint has zero degree.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr="add")
        self.lin = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
        )

    def _compute_edge_weight(
        self,
        edge_index: torch.Tensor,
        edge_w: torch.Tensor,
        deg: torch.Tensor,
    ) -> torch.Tensor:
        """Return one degree-normalised weight per edge, as a flat tensor.

        Args:
            edge_index: Node indices; row 0 holds edge sources, row 1 targets.
            edge_w: Raw edge attribute, exactly one scalar per edge
                (e.g. shape ``(E,)`` or ``(E, 1)``).
            deg: Per-node degree values, indexed by node id
                (e.g. shape ``(N,)`` or ``(N, 1)``).

        Returns:
            Tensor of shape ``(E,)`` with
            ``edge_w / (1 + sqrt(deg[src] * deg[dst]))``.

        Raises:
            ValueError: If ``edge_w`` does not hold exactly one value per edge.
        """
        num_edges = edge_index.size(1)
        if edge_w.numel() != num_edges:
            raise ValueError(
                f"expected one edge weight per edge ({num_edges}), "
                f"got edge_attr with shape {tuple(edge_w.shape)}"
            )
        # Flatten both operands. Without this, an (N, 1) degree tensor gives an
        # (E, 1) denominator, and dividing an (E,) numerator by it silently
        # broadcasts to an (E, E) matrix. reshape(-1) (unlike squeeze()) also
        # keeps a length-1 edge dimension when E == 1.
        edge_w = edge_w.reshape(-1)
        deg = deg.reshape(-1)
        src, dst = edge_index[0], edge_index[1]
        return edge_w / (1 + torch.sqrt(deg[src] * deg[dst]))

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        deg: torch.Tensor,
    ) -> torch.Tensor:
        """Transform node features and aggregate weighted neighbour messages.

        Args:
            x: Node features whose last dimension is ``in_channels``.
            edge_index: Node indices; row 0 holds edge sources, row 1 targets.
            edge_attr: One scalar weight per edge.
            deg: Per-node degree values used for normalisation.

        Returns:
            Aggregated node features with last dimension ``out_channels``.
            Nodes with no incoming edges receive zeros.
        """
        # Apply the MLP so the layer really maps in_channels -> out_channels
        # and its parameters take part in the forward pass (and get gradients).
        x = self.lin(x)
        edge_weight = self._compute_edge_weight(edge_index, edge_attr, deg)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j: torch.Tensor, edge_weight: torch.Tensor) -> torch.Tensor:
        """Scale each gathered source-node feature row by its edge weight."""
        return edge_weight.view(-1, 1) * x_j


class CorrectionNet(nn.Module):
    """Encoder -> one ``GCNConvLayer`` -> decoder, one scalar in and out per node.

    The encoder lifts a 1-feature input to ``hidden_dim`` features. The graph
    layer mixes them along weighted edges, and the decoder maps them back to a
    single feature. Every stage ends in a ReLU, so the output is non-negative.

    Args:
        hidden_dim: Width of the hidden representation. The encoder and
            decoder use an intermediate width of ``hidden_dim // 2``, so this
            must be at least 2.

    Raises:
        ValueError: If ``hidden_dim < 2``, which would create zero-width layers.
    """

    def __init__(self, hidden_dim: int = 8):
        super().__init__()
        if hidden_dim < 2:
            raise ValueError(
                f"hidden_dim must be >= 2 (got {hidden_dim}); "
                "hidden_dim // 2 is used as a layer width"
            )
        half_dim = hidden_dim // 2
        self.enc = nn.Sequential(
            nn.Linear(1, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, hidden_dim),
            nn.ReLU(),
        )
        self.model = GCNConvLayer(hidden_dim, hidden_dim)
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
            nn.ReLU(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        deg: torch.Tensor,
    ) -> torch.Tensor:
        """Run encoder, graph layer and decoder.

        Args:
            x: Node inputs with a last dimension of size 1.
            edge_index: Node indices; row 0 holds edge sources, row 1 targets.
            edge_attr: One scalar weight per edge.
            deg: Per-node degree values used by the graph layer.

        Returns:
            Non-negative per-node outputs with a last dimension of size 1.
        """
        h = self.enc(x)
        h = self.model(h, edge_index, edge_attr, deg)
        return self.dec(h)