Self-loops management in KNNGraph and RadiusGraph (#522)
* Add self-loop option to RadiusGraph and KNNGraph
This commit is contained in:
committed by
Dario Coscia
parent
6ed3ca04fe
commit
ce0c033de1
@@ -67,8 +67,9 @@ def test_build_graph(x, pos):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_build_radius_graph(x, pos):
|
||||
graph = RadiusGraph(x=x, pos=pos, radius=0.5)
|
||||
@pytest.mark.parametrize("loop", [True, False])
|
||||
def test_build_radius_graph(x, pos, loop):
|
||||
graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop)
|
||||
assert hasattr(graph, "x")
|
||||
assert hasattr(graph, "pos")
|
||||
assert hasattr(graph, "edge_index")
|
||||
@@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
|
||||
assert graph.pos.labels == pos.labels
|
||||
else:
|
||||
assert isinstance(graph.pos, torch.Tensor)
|
||||
if not loop:
|
||||
assert (
|
||||
len(
|
||||
torch.nonzero(
|
||||
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
|
||||
)[0]
|
||||
)
|
||||
== 0
|
||||
) # Detect self loops
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_build_knn_graph(x, pos):
|
||||
graph = KNNGraph(x=x, pos=pos, neighbours=2)
|
||||
@pytest.mark.parametrize("loop", [True, False])
|
||||
def test_build_knn_graph(x, pos, loop):
|
||||
graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop)
|
||||
assert hasattr(graph, "x")
|
||||
assert hasattr(graph, "pos")
|
||||
assert hasattr(graph, "edge_index")
|
||||
@@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
|
||||
else:
|
||||
assert isinstance(graph.pos, torch.Tensor)
|
||||
assert graph.edge_attr is None
|
||||
self_loops = len(
|
||||
torch.nonzero(
|
||||
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
|
||||
)[0]
|
||||
)
|
||||
if loop:
|
||||
assert self_loops != 0
|
||||
else:
|
||||
assert self_loops == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user