import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm


class FiniteDifferenceStep(MessagePassing):
    """Message-passing layer built on edge-weighted neighbour differences.

    Data flow in ``forward``:

    1. ``x_embedding`` lifts each node's single input channel to ``hidden_dim``.
    2. ``message`` emits ``(x_j - x_i) * edge_attr`` for every edge.
    3. ``aggregate`` reduces the messages per target node with ``aggr`` and
       divides by ``deg``.
    4. ``update`` concatenates the embedding with the aggregate and passes
       it through ``update_net``.
    5. The result is added to the embedding (residual) and mapped to one
       output channel by ``out_net``.

    Every ``nn.Linear`` is wrapped in ``spectral_norm``.

    Args:
        edge_ch: Currently unused. It is kept so existing call sites keep
            working. ``message`` treats ``edge_attr`` as one scalar per edge.
        hidden_dim: Width of the node embedding. ``hidden_dim // 2`` is the
            width of the middle layer of ``x_embedding`` and ``out_net``.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    # Added to ``deg`` before dividing so that zero entries do not cause a
    # division by zero. A class attribute so it can be tuned per subclass or
    # instance without changing the constructor signature.
    deg_eps = 1e-7

    def __init__(self, edge_ch=5, hidden_dim=16, aggr: str = "add"):
        super().__init__(aggr=aggr)
        # Scalar node feature -> hidden_dim channels.
        self.x_embedding = nn.Sequential(
            spectral_norm(nn.Linear(1, hidden_dim // 2)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
        )
        # [embedding, aggregated messages] (2 * hidden_dim) -> hidden_dim.
        self.update_net = nn.Sequential(
            spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
        )
        # hidden_dim -> single output channel.
        self.out_net = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim // 2, 1)),
        )

    def forward(self, x, edge_index, edge_attr, deg):
        """Apply one finite-difference step to the node features.

        Args:
            x: Node features. The last dimension must be 1 to match
                ``x_embedding`` (presumably shape ``(num_nodes, 1)``; confirm).
            edge_index: Graph connectivity passed to ``propagate``.
            edge_attr: Per-edge weights. ``message`` reshapes them to
                ``(-1, 1)``, so one scalar per edge is expected.
            deg: Per-node normalizer that the aggregated messages are divided
                by. Presumably the in-degree of each node; confirm with caller.

        Returns:
            Output of ``out_net``, whose last dimension is 1.
        """
        h = self.x_embedding(x)
        out = self.propagate(edge_index, x=h, edge_attr=edge_attr, deg=deg)
        # Residual connection around the message-passing update.
        return self.out_net(h + out)

    def message(self, x_i, x_j, edge_attr):
        """Return the edge-weighted difference of neighbouring embeddings.

        ``x_j`` and ``x_i`` are the per-edge gathered embeddings of the
        source and target node, respectively (the default PyG flow).
        """
        return (x_j - x_i) * edge_attr.view(-1, 1)

    def update(self, aggr_out, x):
        """Combine each node's embedding with its normalized aggregate."""
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_net(update_input)

    def aggregate(self, inputs, index, deg, ptr=None, dim_size=None):
        """Aggregate messages per target node, then divide by ``deg``.

        ``ptr`` and ``dim_size`` are filled in by ``propagate`` and forwarded
        to the base aggregation. Without ``dim_size`` the result would have
        only ``index.max() + 1`` rows (zero rows for an empty edge set). It
        would then no longer line up with ``deg`` or ``x`` whenever the
        highest-numbered nodes receive no messages.
        """
        out = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        return out / (deg + self.deg_eps).view(-1, 1)