import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm  # noqa: F401  (currently unused)
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv import GCNConv


class GCNConvLayer(MessagePassing):
    """Edge-weighted message passing normalized by a caller-supplied ``deg``.

    Each target node receives the sum of its neighbours' features, each scaled
    by a per-edge scalar weight. That sum is divided by ``deg`` and then passed
    through the learnable affine map ``lin_l``.

    Args:
        in_channels: Feature size of the input node features ``x``.
        out_channels: Feature size of the layer output.
    """

    def __init__(self, in_channels, out_channels):
        # Plain sum aggregation; the division by ``deg`` happens in ``aggregate``.
        super().__init__(aggr="add")
        self.lin_l = nn.Linear(in_channels, out_channels, bias=True)

    def forward(self, x, edge_index, edge_attr, deg):
        """Propagate ``x`` over ``edge_index`` and project the aggregated result.

        Args:
            x: Node features whose last dimension is ``in_channels``.
            edge_index: Graph connectivity, passed straight to ``propagate``.
            edge_attr: Per-edge weights. ``message`` reshapes them to
                ``(-1, 1)``, so one scalar per edge is expected.
            deg: Per-node normalizer with one entry per target node.
                Presumably the (weighted) in-degree computed by the caller;
                TODO confirm.

        Returns:
            Tensor whose last dimension is ``out_channels``.
        """
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
        return self.lin_l(out)

    def message(self, x_j, edge_attr):
        """Scale each source-node feature row by its edge's scalar weight."""
        return x_j * edge_attr.view(-1, 1)

    def aggregate(self, inputs, index, deg, ptr=None, dim_size=None):
        """Sum the weighted messages per target node, then divide by ``deg``.

        ``propagate`` fills in ``ptr`` and ``dim_size``, and they are forwarded
        to the base aggregation so the result always has one row per node.
        Without ``dim_size`` the scatter sizes its output from
        ``index.max() + 1``. If the highest-indexed nodes receive no messages,
        ``out`` then has fewer rows than ``deg``, and the division below would
        fail or silently broadcast.

        The ``1e-7`` offset avoids dividing by zero. A node with ``deg == 0``
        that receives no messages yields 0 rather than NaN.
        """
        out = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        return out / (deg + 1e-7).view(-1, 1)


class CorrectionNet(nn.Module):
    """Linear encoder, then a stack of ``GCNConv`` layers, then a linear decoder.

    ``forward`` applies no activation between the stages.

    Args:
        input_dim: Size of the input node features.
        output_dim: Size of the per-node output.
        hidden_dim: Width of the encoder output and of every ``GCNConv`` layer.
        n_layers: Number of stacked ``GCNConv`` layers.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
        super().__init__()
        self.enc = nn.Linear(input_dim, hidden_dim, bias=False)
        self.layers = nn.ModuleList(
            [
                GCNConv(hidden_dim, hidden_dim, aggr="mean", bias=False)
                for _ in range(n_layers)
            ]
        )
        self.dec = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_attr):
        """Encode ``x``, run it through every graph layer, and decode it.

        ``edge_attr`` is passed positionally as each ``GCNConv``'s third
        argument (its edge weight). This presumably assumes one scalar per
        edge; verify against callers.
        """
        h = self.enc(x)
        for conv in self.layers:
            h = conv(h, edge_index, edge_attr)
        return self.dec(h)


class MLPNet(nn.Module):
    """Per-node MLP: four ``Linear`` layers with ``ReLU`` between them.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output.
        hidden_dim: Width of the three hidden layers.
        n_layers: Accepted but ignored; the depth is fixed.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=1):
        super().__init__()
        # NOTE(review): ``n_layers`` is unused. The layout below is kept fixed
        # so the Sequential indices (network.0/2/4/6) in saved state dicts
        # still line up.
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, edge_index=None, edge_attr=None):
        """Apply the MLP to ``x``; ``edge_index`` and ``edge_attr`` are ignored.

        The graph arguments are accepted, presumably so this model can be
        called the same way as ``CorrectionNet``.
        """
        return self.network(x)