Self-loops management in KNNGraph and RadiusGraph (#522)
* Add self-loop option to RadiusGraph and KNNGraph
This commit is contained in:
committed by
Dario Coscia
parent
6ed3ca04fe
commit
ce0c033de1
@@ -3,6 +3,7 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data, Batch
|
||||
from torch_geometric.utils import to_undirected
|
||||
from torch_geometric.utils.loop import remove_self_loops
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import check_consistency, is_function
|
||||
|
||||
@@ -209,6 +210,7 @@ class GraphBuilder:
|
||||
x=None,
|
||||
edge_attr=False,
|
||||
custom_edge_func=None,
|
||||
loop=True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -224,18 +226,19 @@ class GraphBuilder:
|
||||
:param x: Optional tensor of node features of shape ``(N, F)``, where
|
||||
``F`` is the number of features per node.
|
||||
:type x: torch.Tensor | LabelTensor, optional
|
||||
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
|
||||
, where ``F`` is the number of features per edge.
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param bool edge_attr: Whether to compute the edge attributes.
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides ``edge_attr``.
|
||||
:type custom_edge_func: Callable, optional
|
||||
:param bool loop: Whether to include self-loops.
|
||||
:param kwargs: Additional keyword arguments passed to the
|
||||
:class:`~pina.graph.Graph` class constructor.
|
||||
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||
provided information.
|
||||
:rtype: Graph
|
||||
"""
|
||||
if not loop:
|
||||
edge_index = remove_self_loops(edge_index)[0]
|
||||
edge_attr = cls._create_edge_attr(
|
||||
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
|
||||
)
|
||||
@@ -374,11 +377,8 @@ class KNNGraph(GraphBuilder):
|
||||
representing the edge indices of the graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
|
||||
:, 1:
|
||||
]
|
||||
knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
|
||||
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
|
||||
col = knn_indices.flatten()
|
||||
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
||||
|
||||
Reference in New Issue
Block a user