Self-loops management in KNNGraph and RadiusGraph (#522)

* Add self-loop option to RadiusGraph and KNNGraph
This commit is contained in:
Filippo Olivo
2025-03-31 16:55:36 +02:00
committed by Dario Coscia
parent 6ed3ca04fe
commit ce0c033de1
2 changed files with 31 additions and 11 deletions

View File

@@ -3,6 +3,7 @@
import torch import torch
from torch_geometric.data import Data, Batch from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected from torch_geometric.utils import to_undirected
from torch_geometric.utils.loop import remove_self_loops
from .label_tensor import LabelTensor from .label_tensor import LabelTensor
from .utils import check_consistency, is_function from .utils import check_consistency, is_function
@@ -209,6 +210,7 @@ class GraphBuilder:
x=None, x=None,
edge_attr=False, edge_attr=False,
custom_edge_func=None, custom_edge_func=None,
loop=True,
**kwargs, **kwargs,
): ):
""" """
@@ -224,18 +226,19 @@ class GraphBuilder:
:param x: Optional tensor of node features of shape ``(N, F)``, where :param x: Optional tensor of node features of shape ``(N, F)``, where
``F`` is the number of features per node. ``F`` is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional :type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)`` :param bool edge_attr: Whether to compute the edge attributes.
, where ``F`` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes. :param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides ``edge_attr``. If provided, overrides ``edge_attr``.
:type custom_edge_func: Callable, optional :type custom_edge_func: Callable, optional
:param bool loop: Whether to include self-loops.
:param kwargs: Additional keyword arguments passed to the :param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor. :class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the :return: A :class:`~pina.graph.Graph` instance constructed using the
provided information. provided information.
:rtype: Graph :rtype: Graph
""" """
if not loop:
edge_index = remove_self_loops(edge_index)[0]
edge_attr = cls._create_edge_attr( edge_attr = cls._create_edge_attr(
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
) )
@@ -374,11 +377,8 @@ class KNNGraph(GraphBuilder):
representing the edge indices of the graph. representing the edge indices of the graph.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
dist = torch.cdist(points, points, p=2) dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[ knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
:, 1:
]
row = torch.arange(points.size(0)).repeat_interleave(neighbours) row = torch.arange(points.size(0)).repeat_interleave(neighbours)
col = knn_indices.flatten() col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)

View File

@@ -67,8 +67,9 @@ def test_build_graph(x, pos):
), ),
], ],
) )
def test_build_radius_graph(x, pos): @pytest.mark.parametrize("loop", [True, False])
graph = RadiusGraph(x=x, pos=pos, radius=0.5) def test_build_radius_graph(x, pos, loop):
graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop)
assert hasattr(graph, "x") assert hasattr(graph, "x")
assert hasattr(graph, "pos") assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index") assert hasattr(graph, "edge_index")
@@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
assert graph.pos.labels == pos.labels assert graph.pos.labels == pos.labels
else: else:
assert isinstance(graph.pos, torch.Tensor) assert isinstance(graph.pos, torch.Tensor)
if not loop:
assert (
len(
torch.nonzero(
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
)[0]
)
== 0
) # Detect self loops
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
), ),
], ],
) )
def test_build_knn_graph(x, pos): @pytest.mark.parametrize("loop", [True, False])
graph = KNNGraph(x=x, pos=pos, neighbours=2) def test_build_knn_graph(x, pos, loop):
graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop)
assert hasattr(graph, "x") assert hasattr(graph, "x")
assert hasattr(graph, "pos") assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index") assert hasattr(graph, "edge_index")
@@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
else: else:
assert isinstance(graph.pos, torch.Tensor) assert isinstance(graph.pos, torch.Tensor)
assert graph.edge_attr is None assert graph.edge_attr is None
self_loops = len(
torch.nonzero(
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
)[0]
)
if loop:
assert self_loops != 0
else:
assert self_loops == 0
@pytest.mark.parametrize( @pytest.mark.parametrize(