Self-loops management in KNNGraph and RadiusGraph (#522)
* Add self-loop option to RadiusGraph and KNNGraph
This commit is contained in:
committed by
Dario Coscia
parent
6ed3ca04fe
commit
ce0c033de1
@@ -3,6 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data, Batch
|
from torch_geometric.data import Data, Batch
|
||||||
from torch_geometric.utils import to_undirected
|
from torch_geometric.utils import to_undirected
|
||||||
|
from torch_geometric.utils.loop import remove_self_loops
|
||||||
from .label_tensor import LabelTensor
|
from .label_tensor import LabelTensor
|
||||||
from .utils import check_consistency, is_function
|
from .utils import check_consistency, is_function
|
||||||
|
|
||||||
@@ -209,6 +210,7 @@ class GraphBuilder:
|
|||||||
x=None,
|
x=None,
|
||||||
edge_attr=False,
|
edge_attr=False,
|
||||||
custom_edge_func=None,
|
custom_edge_func=None,
|
||||||
|
loop=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -224,18 +226,19 @@ class GraphBuilder:
|
|||||||
:param x: Optional tensor of node features of shape ``(N, F)``, where
|
:param x: Optional tensor of node features of shape ``(N, F)``, where
|
||||||
``F`` is the number of features per node.
|
``F`` is the number of features per node.
|
||||||
:type x: torch.Tensor | LabelTensor, optional
|
:type x: torch.Tensor | LabelTensor, optional
|
||||||
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
|
:param bool edge_attr: Whether to compute the edge attributes.
|
||||||
, where ``F`` is the number of features per edge.
|
|
||||||
:type edge_attr: torch.Tensor, optional
|
|
||||||
:param custom_edge_func: A custom function to compute edge attributes.
|
:param custom_edge_func: A custom function to compute edge attributes.
|
||||||
If provided, overrides ``edge_attr``.
|
If provided, overrides ``edge_attr``.
|
||||||
:type custom_edge_func: Callable, optional
|
:type custom_edge_func: Callable, optional
|
||||||
|
:param bool loop: Whether to include self-loops.
|
||||||
:param kwargs: Additional keyword arguments passed to the
|
:param kwargs: Additional keyword arguments passed to the
|
||||||
:class:`~pina.graph.Graph` class constructor.
|
:class:`~pina.graph.Graph` class constructor.
|
||||||
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||||
provided information.
|
provided information.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
if not loop:
|
||||||
|
edge_index = remove_self_loops(edge_index)[0]
|
||||||
edge_attr = cls._create_edge_attr(
|
edge_attr = cls._create_edge_attr(
|
||||||
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
|
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
|
||||||
)
|
)
|
||||||
@@ -374,11 +377,8 @@ class KNNGraph(GraphBuilder):
|
|||||||
representing the edge indices of the graph.
|
representing the edge indices of the graph.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dist = torch.cdist(points, points, p=2)
|
dist = torch.cdist(points, points, p=2)
|
||||||
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
|
knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
|
||||||
:, 1:
|
|
||||||
]
|
|
||||||
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
|
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
|
||||||
col = knn_indices.flatten()
|
col = knn_indices.flatten()
|
||||||
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
||||||
|
|||||||
@@ -67,8 +67,9 @@ def test_build_graph(x, pos):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_build_radius_graph(x, pos):
|
@pytest.mark.parametrize("loop", [True, False])
|
||||||
graph = RadiusGraph(x=x, pos=pos, radius=0.5)
|
def test_build_radius_graph(x, pos, loop):
|
||||||
|
graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop)
|
||||||
assert hasattr(graph, "x")
|
assert hasattr(graph, "x")
|
||||||
assert hasattr(graph, "pos")
|
assert hasattr(graph, "pos")
|
||||||
assert hasattr(graph, "edge_index")
|
assert hasattr(graph, "edge_index")
|
||||||
@@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
|
|||||||
assert graph.pos.labels == pos.labels
|
assert graph.pos.labels == pos.labels
|
||||||
else:
|
else:
|
||||||
assert isinstance(graph.pos, torch.Tensor)
|
assert isinstance(graph.pos, torch.Tensor)
|
||||||
|
if not loop:
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
torch.nonzero(
|
||||||
|
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
== 0
|
||||||
|
) # Detect self loops
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_build_knn_graph(x, pos):
|
@pytest.mark.parametrize("loop", [True, False])
|
||||||
graph = KNNGraph(x=x, pos=pos, neighbours=2)
|
def test_build_knn_graph(x, pos, loop):
|
||||||
|
graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop)
|
||||||
assert hasattr(graph, "x")
|
assert hasattr(graph, "x")
|
||||||
assert hasattr(graph, "pos")
|
assert hasattr(graph, "pos")
|
||||||
assert hasattr(graph, "edge_index")
|
assert hasattr(graph, "edge_index")
|
||||||
@@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
|
|||||||
else:
|
else:
|
||||||
assert isinstance(graph.pos, torch.Tensor)
|
assert isinstance(graph.pos, torch.Tensor)
|
||||||
assert graph.edge_attr is None
|
assert graph.edge_attr is None
|
||||||
|
self_loops = len(
|
||||||
|
torch.nonzero(
|
||||||
|
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
if loop:
|
||||||
|
assert self_loops != 0
|
||||||
|
else:
|
||||||
|
assert self_loops == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
Reference in New Issue
Block a user