import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm


class FiniteDifferenceStep(MessagePassing):
    """One finite-difference update step expressed as message passing.

    Each step takes a degree-normalized weighted sum of neighbor values and
    passes the aggregate through a small spectrally normalized MLP.
    """

    def __init__(self, aggr: str = "add", root_weight: float = 1.0):
        super().__init__(aggr=aggr)
        assert aggr == "add", "For weighted sums, the aggregation must be 'add'."
        # Learned correction to the raw edge weights (currently disabled;
        # see the commented-out branch in message()).
        self.correction_net = nn.Sequential(
            nn.Linear(2, 6),
            nn.Tanh(),
            nn.Linear(6, 1),
            nn.Tanh(),
        )
        self.update_net = nn.Sequential(
            spectral_norm(nn.Linear(1, 6)),
            nn.Softplus(),
            spectral_norm(nn.Linear(6, 1)),
            nn.Softplus(),
        )
        # Currently unused: message() applies the edge weight directly.
        self.message_net = nn.Sequential(
            spectral_norm(nn.Linear(1, 6)),
            nn.Softplus(),
            spectral_norm(nn.Linear(6, 1)),
            nn.Softplus(),
        )
        # Mixing coefficient for the disabled edge-weight correction.
        self.p = torch.nn.Parameter(torch.tensor(0.5))
        # self.a = torch.nn.Parameter(torch.tensor(root_weight))

    def forward(self, x, edge_index, edge_attr, deg):
        """Run one propagation step over the graph."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)

    def message(self, x_j, edge_attr):
        """Weight each neighbor value by its edge coefficient."""
        # x_in = torch.cat([x_j, edge_attr.view(-1, 1)], dim=-1)
        # correction = self.correction_net(x_in)
        # p = torch.sigmoid(self.p)
        # return (p * edge_attr.view(-1, 1) + (1 - p) * correction) * x_j
        return edge_attr.view(-1, 1) * x_j

    def aggregate(self, inputs, index, deg):
        """Sum incoming messages per node and normalize by the weighted degree."""
        out = super().aggregate(inputs, index)
        deg = deg + 1e-7  # guard against division by zero
        return out / deg.view(-1, 1)

    def update(self, aggr_out, x):
        """Apply the spectrally normalized update MLP to the aggregate."""
        return self.update_net(aggr_out)


class GraphFiniteDifference(nn.Module):
    """Iterates FiniteDifferenceStep to a fixed point (or until max_iters)."""

    def __init__(self, max_iters: int = 5000, threshold: float = 1e-4):
        """
        Args:
            max_iters: upper bound on the number of fixed-point iterations.
            threshold: relative residual tolerance for convergence.
        """
        super().__init__()
        self.max_iters = max_iters
        self.threshold = threshold
        self.fd_step = FiniteDifferenceStep(aggr="add", root_weight=1.0)

    @staticmethod
    def _compute_deg(edge_index, edge_attr, num_nodes):
        """Weighted in-degree of each node, with an epsilon for stability."""
        deg = torch.zeros(num_nodes, device=edge_index.device)
        deg = deg.scatter_add(0, edge_index[1], edge_attr)
        return deg + 1e-7

    @staticmethod
    def _compute_c_ij(c, edge_index):
        """Per-edge coefficient: the average of the two endpoint coefficients."""
        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()

    def forward(
        self,
        x,
        edge_index,
        edge_attr,
        c,
        boundary_mask,
        boundary_values,
        **kwargs,
    ):
        """Iterate the finite-difference step until the residual norm drops
        below `threshold * ||x||`, re-imposing the Dirichlet boundary values
        after every step.

        Returns:
            The final field and the number of iterations performed.
        """
        # Edge weight: inverse distance (last edge feature), scaled by the
        # averaged node coefficient c_ij.
        edge_attr = 1 / edge_attr[:, -1]
        c_ij = self._compute_c_ij(c, edge_index)
        edge_attr = edge_attr * c_ij
        deg = self._compute_deg(edge_index, edge_attr, x.size(0))

        conv_thres = self.threshold * torch.norm(x.detach())
        for _i in range(self.max_iters):
            out = self.fd_step(x, edge_index, edge_attr, deg)
            # Re-impose Dirichlet boundary conditions after each step.
            out[boundary_mask] = boundary_values.unsqueeze(-1)
            with torch.no_grad():
                residual_norm = torch.norm(out - x)
            if residual_norm < conv_thres:
                break
            # Detach so gradients only flow through the final step.
            x = out.detach()
        return out, _i + 1
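

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original
    # module): solve for a scalar field on a 1D chain of 5 nodes, holding
    # the two end nodes at fixed Dirichlet values. All tensors below are
    # toy assumptions chosen to exercise the solver.

    num_nodes = 5
    # Bidirectional chain edges: 0-1, 1-2, 2-3, 3-4.
    src = torch.tensor([0, 1, 1, 2, 2, 3, 3, 4])
    dst = torch.tensor([1, 0, 2, 1, 3, 2, 4, 3])
    edge_index = torch.stack([src, dst])

    # Edge features whose last column is the inter-node distance;
    # forward() uses its reciprocal as the base edge weight.
    edge_attr = torch.ones(edge_index.size(1), 1)

    # Nonzero initial guess, so the relative tolerance threshold * ||x||
    # is itself nonzero and the loop can terminate early.
    x = torch.full((num_nodes, 1), 0.5)
    c = torch.ones(num_nodes, 1)  # per-node coefficient
    boundary_mask = torch.tensor([True, False, False, False, True])
    boundary_values = torch.tensor([0.0, 1.0])  # values for the masked nodes

    model = GraphFiniteDifference(max_iters=1000, threshold=1e-4)
    out, n_iters = model(
        x, edge_index, edge_attr, c, boundary_mask, boundary_values
    )
    print(f"field: {out.squeeze().tolist()} after {n_iters} iterations")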