""" This module provides an interface to build torch_geometric.data.Data objects. """ import torch from torch_geometric.data import Data from torch_geometric.utils import to_undirected from . import LabelTensor from .utils import check_consistency, is_function class Graph(Data): """ A class to build torch_geometric.data.Data objects. """ def __new__( cls, **kwargs, ): """ :param kwargs: Parameters to construct the Graph object. :return: A new instance of the Graph class. :rtype: Graph """ # create class instance instance = Data.__new__(cls) # check the consistency of types defined in __init__, the others are not # checked (as in pyg Data object) instance._check_type_consistency(**kwargs) return instance def __init__( self, x=None, edge_index=None, pos=None, edge_attr=None, undirected=False, **kwargs, ): """ Initialize the Graph object. :param x: Optional tensor of node features (N, F) where F is the number of features per node. :type x: torch.Tensor, LabelTensor :param torch.Tensor edge_index: A tensor of shape (2, E) representing the indices of the graph's edges. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type pos: torch.Tensor | LabelTensor :param edge_attr: Optional tensor of edge_featured (E, F') where F' is the number of edge features :param bool undirected: Whether to make the graph undirected :param kwargs: Additional keyword arguments passed to the `torch_geometric.data.Data` class constructor. If the argument is a `torch.Tensor` or `LabelTensor`, it is included in the Data object as a graph parameter. """ # preprocessing self._preprocess_edge_index(edge_index, undirected) # calling init super().__init__( x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs ) def _check_type_consistency(self, **kwargs): # default types, specified in cls.__new__, by default they are Nont # if specified in **kwargs they get override x, pos, edge_index, edge_attr = None, None, None, None if "pos" in kwargs: pos = kwargs["pos"] self._check_pos_consistency(pos) if "edge_index" in kwargs: edge_index = kwargs["edge_index"] self._check_edge_index_consistency(edge_index) if "x" in kwargs: x = kwargs["x"] self._check_x_consistency(x, pos) if "edge_attr" in kwargs: edge_attr = kwargs["edge_attr"] self._check_edge_attr_consistency(edge_attr, edge_index) if "undirected" in kwargs: undirected = kwargs["undirected"] check_consistency(undirected, bool) @staticmethod def _check_pos_consistency(pos): """ Check if the position tensor is consistent. :param torch.Tensor pos: The position tensor. """ if pos is not None: check_consistency(pos, (torch.Tensor, LabelTensor)) if pos.ndim != 2: raise ValueError("pos must be a 2D tensor.") @staticmethod def _check_edge_index_consistency(edge_index): """ Check if the edge index is consistent. :param torch.Tensor edge_index: The edge index tensor. """ check_consistency(edge_index, (torch.Tensor, LabelTensor)) if edge_index.ndim != 2: raise ValueError("edge_index must be a 2D tensor.") if edge_index.size(0) != 2: raise ValueError("edge_index must have shape [2, num_edges].") @staticmethod def _check_edge_attr_consistency(edge_attr, edge_index): """ Check if the edge attr is consistent. :param torch.Tensor edge_attr: The edge attribute tensor. :param torch.Tensor edge_index: The edge index tensor. """ if edge_attr is not None: check_consistency(edge_attr, (torch.Tensor, LabelTensor)) if edge_attr.ndim != 2: raise ValueError("edge_attr must be a 2D tensor.") if edge_attr.size(0) != edge_index.size(1): raise ValueError( "edge_attr must have shape " "[num_edges, num_edge_features], expected " f"num_edges {edge_index.size(1)} " f"got {edge_attr.size(0)}." ) @staticmethod def _check_x_consistency(x, pos=None): """ Check if the input tensor x is consistent with the position tensor pos. :param torch.Tensor x: The input tensor. :param torch.Tensor pos: The position tensor. """ if x is not None: check_consistency(x, (torch.Tensor, LabelTensor)) if x.ndim != 2: raise ValueError("x must be a 2D tensor.") if pos is not None: if x.size(0) != pos.size(0): raise ValueError("Inconsistent number of nodes.") if pos is not None: if x.size(0) != pos.size(0): raise ValueError("Inconsistent number of nodes.") @staticmethod def _preprocess_edge_index(edge_index, undirected): """ Preprocess the edge index. :param torch.Tensor edge_index: The edge index. :param bool undirected: Whether the graph is undirected. :return: The preprocessed edge index. :rtype: torch.Tensor """ if undirected: edge_index = to_undirected(edge_index) return edge_index class GraphBuilder: """ A class that allows the simple definition of Graph instances. """ def __new__( cls, pos, edge_index, x=None, edge_attr=False, custom_edge_func=None, **kwargs, ): """ Creates a new instance of the Graph class. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type pos: torch.Tensor | LabelTensor :param edge_index: A tensor of shape (2, E) representing the indices of the graph's edges. :type edge_index: torch.Tensor :param x: Optional tensor of node features (N, F) where F is the number of features per node. :type x: torch.Tensor, LabelTensor :param bool edge_attr: Optional edge attributes (E, F) where F is the number of features per edge. :param callable custom_edge_func: A custom function to compute edge attributes. :param kwargs: Additional keyword arguments passed to the Graph class constructor. :return: A Graph instance constructed using the provided information. :rtype: Graph """ edge_attr = cls._create_edge_attr( pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr ) return Graph( x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs, ) @staticmethod def _create_edge_attr(pos, edge_index, edge_attr, func): check_consistency(edge_attr, bool) if edge_attr: if is_function(func): return func(pos, edge_index) raise ValueError("custom_edge_func must be a function.") return None @staticmethod def _build_edge_attr(pos, edge_index): return ( (pos[edge_index[0]] - pos[edge_index[1]]) .abs() .as_subclass(torch.Tensor) ) class RadiusGraph(GraphBuilder): """ A class to build a radius graph. """ def __new__(cls, pos, radius, **kwargs): """ Creates a new instance of the Graph class using a radius-based graph construction. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type pos: torch.Tensor | LabelTensor :param float radius: The radius within which points are connected. :Keyword Arguments: The additional keyword arguments to be passed to GraphBuilder and Graph classes :return: Graph instance containg the information passed in input and the computed edge_index :rtype: Graph """ edge_index = cls.compute_radius_graph(pos, radius) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def compute_radius_graph(points, radius): """ Computes a radius-based graph for a given set of points. :param points: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type points: torch.Tensor | LabelTensor :param float radius: The number of nearest neighbors to find for each point. :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of edges, representing the edge indices of the KNN graph. """ dist = torch.cdist(points, points, p=2) return ( torch.nonzero(dist <= radius, as_tuple=False) .t() .as_subclass(torch.Tensor) ) class KNNGraph(GraphBuilder): """ A class to build a KNN graph. """ def __new__(cls, pos, neighbours, **kwargs): """ Creates a new instance of the Graph class using k-nearest neighbors to compute edge_index. :param pos: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type pos: torch.Tensor | LabelTensor :param int neighbours: The number of nearest neighbors to consider when building the graph. :Keyword Arguments: The additional keyword arguments to be passed to GraphBuilder and Graph classes :return: Graph instance containg the information passed in input and the computed edge_index :rtype: Graph """ edge_index = cls.compute_knn_graph(pos, neighbours) return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) @staticmethod def compute_knn_graph(points, k): """ Computes the edge_index based k-nearest neighbors graph algorithm :param points: A tensor of shape (N, D) representing the positions of N points in D-dimensional space. :type points: torch.Tensor | LabelTensor :param int k: The number of nearest neighbors to find for each point. :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of edges, representing the edge indices of the KNN graph. """ dist = torch.cdist(points, points, p=2) knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:] row = torch.arange(points.size(0)).repeat_interleave(k) col = knn_indices.flatten() return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)