297 lines
13 KiB
Python
297 lines
13 KiB
Python
from logging import warning
|
|
|
|
import torch
|
|
|
|
from . import LabelTensor
|
|
from torch_geometric.data import Data
|
|
from torch_geometric.utils import to_undirected
|
|
import inspect
|
|
|
|
|
|
class Graph:
    """
    Class for the graph construction.

    The built graphs are stored in ``self.data`` as a list of PyTorch
    Geometric ``Data`` objects, one entry per graph.
    """

    def __init__(
        self,
        x,
        pos,
        edge_index,
        edge_attr=None,
        build_edge_attr=False,
        undirected=False,
        custom_build_edge_attr=None,
        additional_params=None
    ):
        """
        Constructor for the Graph class. This object creates a list of
        PyTorch Geometric Data objects.
        Based on the input of x and pos there could be the following cases:
        1. 1 pos, 1 x: a single graph will be created
        2. N pos, 1 x: N graphs will be created with the same node features
        3. 1 pos, N x: N graphs will be created with the same nodes but
           different node features
        4. N pos, N x: N graphs will be created

        :param x: Node features. Can be a single 2D tensor of shape
            [num_nodes, num_node_features],
            or a 3D tensor of shape [n_graphs, num_nodes, num_node_features]
            or a list of such 2D tensors of shape
            [num_nodes, num_node_features].
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: Node coordinates. Can be a single 2D tensor of shape
            [num_nodes, num_coordinates],
            or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates]
            or a list of such 2D tensors of shape
            [num_nodes, num_coordinates].
        :type pos: torch.Tensor or list[torch.Tensor]
        :param edge_index: The edge index defining connections between nodes.
            It should be a 2D tensor of shape [2, num_edges]
            or a 3D tensor of shape [n_graphs, 2, num_edges]
            or a list of such 2D tensors of shape [2, num_edges].
            A single 2D tensor is shared by all the graphs.
        :type edge_index: torch.Tensor or list[torch.Tensor]
        :param edge_attr: Edge features. If provided, should have the shape
            [num_edges, num_edge_features]
            or be a list of such tensors for multiple graphs.
        :type edge_attr: torch.Tensor or list[torch.Tensor], optional
        :param build_edge_attr: Whether to compute edge attributes during
            initialization. Ignored (with a warning) if edge_attr is given.
        :type build_edge_attr: bool, default=False
        :param undirected: If True, converts the graph(s) into an undirected
            graph by adding reciprocal edges. Only edge_index is converted;
            a user-supplied edge_attr is used as given.
        :type undirected: bool, default=False
        :param custom_build_edge_attr: A user-defined function to generate
            edge attributes dynamically.
            The function should take (x, pos, edge_index) of a single graph
            as input and return a tensor of shape
            [num_edges, num_edge_features]. Any callable is accepted;
            non-callable values are ignored.
        :type custom_build_edge_attr: function or callable, optional
        :param additional_params: Dictionary containing extra attributes to
            be added to each Data object.
            Keys represent attribute names, and values should be tensors or
            lists of tensors. The dictionary itself is not modified.
        :type additional_params: dict, optional
        :raises TypeError: If x, pos, edge_index, additional_params or a
            list-valued edge_attr do not satisfy the constraints above.
        :raises ValueError: If x, pos and edge_index are lists (or 3D
            tensors) describing a different number of graphs.

        Note: if x, pos, and edge_index are all lists or 3D tensors, then
        len(x) == len(pos) == len(edge_index).
        """

        self.data = []
        x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)

        # Check input dimension consistency and store the number of graphs
        data_len = self._check_len_consistency(x, pos)

        # Stored as an instance attribute, so it shadows the default static
        # builder and is called without ``self``.
        if callable(custom_build_edge_attr):
            self._build_edge_attr = custom_build_edge_attr

        # Check consistency and initialize additional_parameters (if present)
        additional_params = self._check_additional_params(additional_params,
                                                          data_len)

        # Make the graphs undirected
        if undirected:
            if isinstance(edge_index, list):
                edge_index = [to_undirected(e) for e in edge_index]
            else:
                edge_index = to_undirected(edge_index)

        # From here on x, pos and edge_index are lists with one entry per
        # graph: anything given once is shared by all the graphs.
        if not isinstance(x, list):
            x = [x] * data_len
        if not isinstance(pos, list):
            pos = [pos] * data_len
        if isinstance(edge_index, list):
            if len(edge_index) != data_len:
                raise ValueError("edge_index must have the same length as x "
                                 "and pos.")
        else:
            edge_index = [edge_index] * data_len

        # Build the edge attributes
        edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
                                                    data_len, edge_index, pos,
                                                    x)

        # Perform the graph construction
        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)

    def _build_graph_list(self, x, pos, edge_index, edge_attr,
                          additional_params):
        """
        Create one Data object per graph and append it to ``self.data``.

        :param list x: Node features, one entry per graph.
        :param list pos: Node coordinates, one entry per graph.
        :param list edge_index: Edge indices, one entry per graph.
        :param edge_attr: Edge attributes, one entry per graph, or None.
        :type edge_attr: list or None
        :param dict additional_params: Extra Data attributes; each value is
            indexed with the graph position.
        """
        for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
            # Data receives the plain tensor wrapped by the LabelTensor
            if isinstance(x_, LabelTensor):
                x_ = x_.tensor
            add_params_local = {k: v[i] for k, v in additional_params.items()}
            if edge_attr is not None:
                graph = Data(x=x_, pos=pos_, edge_index=edge_index_,
                             edge_attr=edge_attr[i], **add_params_local)
            else:
                graph = Data(x=x_, pos=pos_, edge_index=edge_index_,
                             **add_params_local)
            self.data.append(graph)

    @staticmethod
    def _build_edge_attr(x, pos, edge_index):
        """
        Default edge attributes: the absolute coordinate-wise difference
        between the positions of the two nodes of each edge.

        :param x: Node features (unused by the default implementation).
        :param torch.Tensor pos: Node coordinates.
        :param torch.Tensor edge_index: Edge index; rows 0 and 1 are used to
            index ``pos``.
        :return: The edge attributes.
        :rtype: torch.Tensor
        """
        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
        return distance

    @staticmethod
    def _check_len_consistency(x, pos):
        """
        Return the number of graphs described by x and pos.

        :param x: Node features (list or single tensor).
        :param pos: Node coordinates (list or single tensor).
        :return: The length of the list input(s), or 1 if neither is a list.
        :rtype: int
        :raises ValueError: If both are lists of different lengths.
        """
        x_is_list = isinstance(x, list)
        pos_is_list = isinstance(pos, list)
        if x_is_list and pos_is_list:
            if len(x) != len(pos):
                raise ValueError("x and pos must have the same length.")
            return len(x)
        if x_is_list:
            return len(x)
        if pos_is_list:
            return len(pos)
        return 1

    @staticmethod
    def _split_or_validate(value, name):
        """
        Validate one of the x / pos / edge_index inputs.

        :param value: The input to check.
        :param str name: The argument name used in the error message.
        :return: A list of 2D slices if value is a 3D tensor, otherwise
            value itself (a 2D tensor or a list of 2D items).
        :raises TypeError: If value is none of the accepted forms.
        """
        if isinstance(value, torch.Tensor):
            # A 3D tensor is split along its first dimension
            if value.ndim == 3:
                return [value[i] for i in range(value.shape[0])]
            if value.ndim == 2:
                return value
        elif isinstance(value, list) and all(
                getattr(t, "ndim", None) == 2 for t in value):
            return value
        raise TypeError(f"{name} must be either a list of 2D tensors or a 2D "
                        "tensor or a 3D tensor")

    @staticmethod
    def _check_input_consistency(x, pos, edge_index=None):
        """
        Validate x, pos and (if given) edge_index, splitting 3D tensors into
        lists of 2D tensors.

        :param x: Node features.
        :param pos: Node coordinates.
        :param edge_index: Edge index; skipped when None.
        :return: The (possibly split) x, pos and edge_index.
        :rtype: tuple
        :raises TypeError: If an input is not a 2D tensor, a 3D tensor or a
            list of 2D tensors.
        """
        x = Graph._split_or_validate(x, "x")
        pos = Graph._split_or_validate(pos, "pos")
        if edge_index is not None:
            edge_index = Graph._split_or_validate(edge_index, "edge_index")
        return x, pos, edge_index

    @staticmethod
    def _check_additional_params(additional_params, data_len):
        """
        Turn every additional parameter into a per-graph sequence.

        Tensor values are converted as follows:
        - 3 or more dimensions: split along the first dimension;
        - 1 dimension of length data_len: one scalar per graph;
        - anything else: the same tensor is shared by all the graphs.
        List values are kept as they are.

        :param additional_params: The user-supplied dictionary or None.
        :type additional_params: dict or None
        :param int data_len: The number of graphs.
        :return: A new dictionary (the input one is left untouched); empty
            if additional_params is None.
        :rtype: dict
        :raises TypeError: If additional_params is not a dictionary or a
            value is neither a tensor nor a list.
        """
        if additional_params is None:
            return {}
        if not isinstance(additional_params, dict):
            raise TypeError("additional_params must be a dictionary.")

        checked = {}
        for param, val in additional_params.items():
            if isinstance(val, torch.Tensor):
                per_graph = val.ndim >= 3 or (
                    val.ndim == 1 and len(val) == data_len)
                if per_graph:
                    checked[param] = [val[i] for i in range(val.shape[0])]
                else:
                    checked[param] = [val] * data_len
            elif isinstance(val, list):
                checked[param] = val
            else:
                raise TypeError("additional_params values must be tensors "
                                "or lists of tensors.")
        return checked

    def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
                                   edge_index, pos, x):
        """
        Return the per-graph edge attributes.

        :param edge_attr: User-supplied edge attributes or None.
        :type edge_attr: torch.Tensor or list or None
        :param bool build_edge_attr: Whether to compute the attributes when
            edge_attr is None.
        :param int data_len: The number of graphs.
        :param list edge_index: Edge indices, one entry per graph.
        :param list pos: Node coordinates, one entry per graph.
        :param x: Node features, one entry per graph (a single tensor is
            shared by all the graphs).
        :return: A list with one entry per graph, or None if there are no
            edge attributes.
        :rtype: list or None
        :raises TypeError: If edge_attr is a list whose length differs from
            data_len.
        """
        # Check if edge_attr is consistent with x and pos
        if edge_attr is not None:
            if build_edge_attr is True:
                warning("edge_attr is not None. build_edge_attr will not be "
                        "considered.")
            if isinstance(edge_attr, list):
                if len(edge_attr) != data_len:
                    raise TypeError("edge_attr must have the same length as x "
                                    "and pos.")
                # Already one entry per graph: must not be wrapped again
                return edge_attr
            return [edge_attr] * data_len

        if build_edge_attr:
            x_list = x if isinstance(x, list) else [x] * data_len
            # The builder works on a single graph at a time
            return [self._build_edge_attr(x_, pos_, edge_index_) for
                    x_, pos_, edge_index_ in zip(x_list, pos, edge_index)]

        return None
|
|
|
|
|
|
class RadiusGraph(Graph):
    """
    Graph whose edges are derived from the node coordinates: two nodes are
    connected when their Euclidean distance does not exceed a radius.
    """

    def __init__(
        self,
        x,
        pos,
        r,
        **kwargs
    ):
        """
        Build the edge index from ``pos`` and delegate to :class:`Graph`.

        :param x: Node features, forwarded to :class:`Graph`.
        :param pos: Node coordinates; a single tensor or a list of tensors
            once validated by ``Graph._check_input_consistency``.
        :param r: The radius.
        :type r: float
        :param kwargs: Remaining keyword arguments forwarded to
            :class:`Graph`.
        """
        # No edge_index is passed here, so the third returned item is None
        x, pos, _ = Graph._check_input_consistency(x, pos)

        single_graph = isinstance(pos, (torch.Tensor, LabelTensor))
        edge_index = (
            RadiusGraph._radius_graph(pos, r)
            if single_graph
            else [RadiusGraph._radius_graph(p, r) for p in pos]
        )

        super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs)

    @staticmethod
    def _radius_graph(points, r):
        """
        Implementation of the radius graph construction.

        Every ordered pair of points whose Euclidean distance is at most
        ``r`` becomes an edge; pairs of a point with itself are not
        filtered out.

        :param points: The input points.
        :type points: torch.Tensor
        :param r: The radius.
        :type r: float
        :return: The edge index.
        :rtype: torch.Tensor
        """
        within_radius = torch.cdist(points, points, p=2) <= r
        pairs = torch.nonzero(within_radius, as_tuple=False)
        # nonzero returns one (i, j) pair per row; transpose to 2 rows
        return pairs.t()
|
|
|
|
|
|
class KNNGraph(Graph):
    """
    Graph whose edges are derived from the node coordinates: every node is
    connected to its k nearest neighbors (Euclidean distance).
    """

    def __init__(
        self,
        x,
        pos,
        k,
        **kwargs
    ):
        """
        Build the edge index from ``pos`` and delegate to :class:`Graph`.

        :param x: Node features, forwarded to :class:`Graph`.
        :param pos: Node coordinates; a single tensor or a list of tensors
            once validated by ``Graph._check_input_consistency``.
        :param k: The number of nearest neighbors.
        :type k: int
        :param kwargs: Remaining keyword arguments forwarded to
            :class:`Graph`.
        """
        # No edge_index is passed here, so the third returned item is None
        x, pos, _ = Graph._check_input_consistency(x, pos)
        if isinstance(pos, (torch.Tensor, LabelTensor)):
            edge_index = KNNGraph._knn_graph(pos, k)
        else:
            edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
        super().__init__(x=x, pos=pos, edge_index=edge_index,
                         **kwargs)

    @staticmethod
    def _knn_graph(points, k):
        """
        Implementation of the k-nearest neighbors graph construction.

        :param points: The input points.
        :type points: torch.Tensor
        :param k: The number of nearest neighbors.
        :type k: int
        :return: The edge index: row 0 holds each point index repeated k
            times, row 1 the indices of the selected neighbors.
        :rtype: torch.Tensor
        """
        dist = torch.cdist(points, points, p=2)
        # k + 1 smallest distances per point; the first one is dropped on
        # the assumption that it is the point itself (distance zero).
        knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
        # Created on the same device as the neighbor indices, otherwise
        # torch.stack fails when the points do not live on the default device
        row = torch.arange(
            points.size(0), device=knn_indices.device
        ).repeat_interleave(k)
        col = knn_indices.flatten()
        edge_index = torch.stack([row, col], dim=0)
        return edge_index
|