import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm


class FiniteDifferenceStep(MessagePassing):
    """Degree-normalised, edge-weighted sum of neighbour features.

    For every target node the layer sums ``x_j * edge_attr`` over its
    incoming edges (``aggr="add"``) and divides the sum by the node's
    entry in ``deg`` (plus a small epsilon). The layer has no learnable
    parameters.
    """

    def __init__(self) -> None:
        """Create the layer with sum ("add") aggregation."""
        super().__init__(aggr="add")

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        deg: torch.Tensor,
    ) -> torch.Tensor:
        """Run one message-passing step.

        Args:
            x: Node features. Assumed to be ``(num_nodes, num_features)``
                -- TODO confirm against callers.
            edge_index: Graph connectivity handed to ``propagate``.
            edge_attr: Per-edge weights multiplied into the source
                features; must broadcast against the gathered ``x_j``.
            deg: Per-node normaliser; it is flattened to one value per
                node before the division in ``aggregate``.

        Returns:
            The aggregated, degree-normalised node features.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)

    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Weight each source-node feature by its edge attribute.

        Args:
            x_j: Source-node features gathered per edge.
            edge_attr: Per-edge weights.

        Returns:
            The element-wise product ``x_j * edge_attr``.
        """
        return x_j * edge_attr

    def update(self, aggr_out: torch.Tensor, _=None) -> torch.Tensor:
        """Return the aggregated features unchanged.

        Args:
            aggr_out: Output of ``aggregate``.
            _: Unused placeholder kept for backward compatibility. It
                defaults to ``None`` because ``forward`` never passes a
                value for it to ``propagate``; without a default, the
                hook would depend on how the installed PyG version
                treats a required argument it cannot fill.

        Returns:
            ``aggr_out`` as-is.
        """
        return aggr_out

    def aggregate(
        self,
        inputs: torch.Tensor,
        index: torch.Tensor,
        deg: torch.Tensor,
        dim_size=None,
    ) -> torch.Tensor:
        """Sum the messages per target node and divide by the degree.

        Args:
            inputs: Per-edge messages produced by ``message``.
            index: Target-node index of every message.
            deg: Per-node normaliser, one value per node.
            dim_size: Number of target nodes, supplied by ``propagate``.
                Forwarding it keeps one output row per node even when
                the highest-numbered nodes receive no messages, so the
                result always lines up with ``deg``.

        Returns:
            The summed messages divided by ``deg + 1e-7``.
        """
        summed = super().aggregate(inputs, index, dim_size=dim_size)
        # The epsilon keeps the division finite for nodes whose degree is 0.
        safe_deg = deg + 1e-7
        # Column shape so the normaliser broadcasts over the feature axis
        # (assumes `summed` is 2-D -- TODO confirm).
        return summed / safe_deg.view(-1, 1)