Refactor Graph class to support custom edge attribute logic

Author: FilippoOlivo
Date: 2025-02-05 10:59:48 +01:00
Committed by: Nicola Demo
Parent: 78b276d995
Commit: bbdd5d4bf1
2 changed files with 128 additions and 86 deletions

View File

@@ -5,6 +5,7 @@ import torch
 from . import LabelTensor
 from torch_geometric.data import Data
 from torch_geometric.utils import to_undirected
+import inspect
 
 
 class Graph:
@@ -12,14 +13,17 @@ class Graph:
     Class for the graph construction.
     """
 
-    def __init__(self,
+    def __init__(
+            self,
             x,
             pos,
             edge_index,
             edge_attr=None,
             build_edge_attr=False,
             undirected=False,
-            additional_params=None):
+            custom_build_edge_attr=None,
+            additional_params=None
+    ):
         """
         Constructor for the Graph class.
         :param x: The node features.
@@ -34,45 +38,23 @@ class Graph:
         :type build_edge_attr: bool
         :param undirected: Whether to build an undirected graph.
         :type undirected: bool
+        :param custom_build_edge_attr: Custom function to build the edge
+            attributes.
+        :type custom_build_edge_attr: function
         :param additional_params: Additional parameters.
         :type additional_params: dict
         """
         self.data = []
-        x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)
+        x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
 
         # Check input dimension consistency and store the number of graphs
         data_len = self._check_len_consistency(x, pos)
 
+        if inspect.isfunction(custom_build_edge_attr):
+            self._build_edge_attr = custom_build_edge_attr
+
-        # Initialize additional_parameters (if present)
-        if additional_params is not None:
-            if not isinstance(additional_params, dict):
-                raise TypeError("additional_params must be a dictionary.")
-            for param, val in additional_params.items():
-                # Check if the values are tensors or lists of tensors
-                if isinstance(val, torch.Tensor):
-                    # If the tensor is 3D, we split it into a list of 2D tensors
-                    # In this case there must be a additional parameter for each
-                    # node
-                    if val.ndim == 3:
-                        additional_params[param] = [val[i] for i in
-                                                    range(val.shape[0])]
-                    # If the tensor is 2D, we replicate it for each node
-                    elif val.ndim == 2:
-                        additional_params[param] = [val] * data_len
-                    # If the tensor is 1D, each graph has a scalar values as
-                    # additional parameter
-                    if val.ndim == 1:
-                        if len(val) == data_len:
-                            additional_params[param] = [val[i] for i in
-                                                        range(len(val))]
-                        else:
-                            additional_params[param] = [val for _ in
-                                                        range(data_len)]
-                elif not isinstance(val, list):
-                    raise TypeError("additional_params values must be tensors "
-                                    "or lists of tensors.")
-        else:
-            additional_params = {}
+        # Check consistency and initialize additional_parameters (if present)
+        additional_params = self._check_additional_params(additional_params,
+                                                          data_len)
 
         # Make the graphs undirected
         if undirected:
@@ -81,27 +63,17 @@ class Graph:
             else:
                 edge_index = to_undirected(edge_index)
 
-        if build_edge_attr:
-            if edge_attr is not None:
-                warning("Edge attributes are provided, build_edge_attr is set "
-                        "to True. The provided edge attributes will be ignored.")
-            edge_attr = self._build_edge_attr(pos, edge_index)
 
         # Prepare internal lists to create a graph list (same positions but
         # different node features)
         if isinstance(x, list) and isinstance(pos,
                                               (torch.Tensor, LabelTensor)):
             # Replicate the positions, edge_index and edge_attr
             pos, edge_index = [pos] * data_len, [edge_index] * data_len
-            if edge_attr is not None:
-                edge_attr = [edge_attr] * data_len
         # Prepare internal lists to create a list containing a single graph
         elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
                 torch.Tensor, LabelTensor)):
             # Encapsulate the input tensors into lists
             x, pos, edge_index = [x], [pos], [edge_index]
-            if isinstance(edge_attr, torch.Tensor):
-                edge_attr = [edge_attr]
         # Prepare internal lists to create a list of graphs (same node features
         # but different positions)
         elif (isinstance(x, (torch.Tensor, LabelTensor))
@@ -111,6 +83,10 @@ class Graph:
         elif not isinstance(x, list) and not isinstance(pos, list):
             raise TypeError("x and pos must be lists or tensors.")
 
+        # Build the edge attributes
+        edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
+                                                    data_len, edge_index, pos, x)
+
         # Perform the graph construction
         self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
@@ -130,12 +106,8 @@ class Graph:
                                       **add_params_local))
 
     @staticmethod
-    def _build_edge_attr(pos, edge_index):
-        if isinstance(pos, torch.Tensor):
-            pos = [pos]
-            edge_index = [edge_index]
-        distance = [pos_[edge_index_[0]] - pos_[edge_index_[1]] ** 2 for
-                    pos_, edge_index_ in zip(pos, edge_index)]
+    def _build_edge_attr(x, pos, edge_index):
+        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
         return distance
 
     @staticmethod
@@ -166,15 +138,65 @@ class Graph:
             edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
         return x, pos, edge_index
 
+    @staticmethod
+    def _check_additional_params(additional_params, data_len):
+        if additional_params is not None:
+            if not isinstance(additional_params, dict):
+                raise TypeError("additional_params must be a dictionary.")
+            for param, val in additional_params.items():
+                # Check if the values are tensors or lists of tensors
+                if isinstance(val, torch.Tensor):
+                    # If the tensor is 3D, we split it into a list of 2D tensors
+                    # In this case there must be a additional parameter for each
+                    # node
+                    if val.ndim == 3:
+                        additional_params[param] = [val[i] for i in
+                                                    range(val.shape[0])]
+                    # If the tensor is 2D, we replicate it for each node
+                    elif val.ndim == 2:
+                        additional_params[param] = [val] * data_len
+                    # If the tensor is 1D, each graph has a scalar values as
+                    # additional parameter
+                    if val.ndim == 1:
+                        if len(val) == data_len:
+                            additional_params[param] = [val[i] for i in
+                                                        range(len(val))]
+                        else:
+                            additional_params[param] = [val for _ in
+                                                        range(data_len)]
+                elif not isinstance(val, list):
+                    raise TypeError("additional_params values must be tensors "
+                                    "or lists of tensors.")
+        else:
+            additional_params = {}
+        return additional_params
+
+    def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
+                                   edge_index, pos, x):
+        # Check if edge_attr is consistent with x and pos
+        if edge_attr is not None:
+            if build_edge_attr is True:
+                warning("edge_attr is not None. build_edge_attr will not be "
+                        "considered.")
+            if isinstance(edge_attr, list):
+                if len(edge_attr) != data_len:
+                    raise ValueError("edge_attr must have the same length as x "
+                                     "and pos.")
+            return [edge_attr] * data_len
+        if build_edge_attr:
+            return [self._build_edge_attr(x, pos_, edge_index_) for
+                    pos_, edge_index_ in zip(pos, edge_index)]
 
 
 class RadiusGraph(Graph):
-    def __init__(self,
+    def __init__(
+            self,
             x,
             pos,
             r,
-            build_edge_attr=False,
-            undirected=False,
-            additional_params=None,
-            ):
+            **kwargs
+    ):
         x, pos, edge_index = Graph._check_input_consistency(x, pos)
         if isinstance(pos, (torch.Tensor, LabelTensor)):
@@ -183,9 +205,7 @@ class RadiusGraph(Graph):
             edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
 
         super().__init__(x=x, pos=pos, edge_index=edge_index,
-                         build_edge_attr=build_edge_attr,
-                         undirected=undirected,
-                         additional_params=additional_params)
+                         **kwargs)
 
     @staticmethod
     def _radius_graph(points, r):
@@ -204,13 +224,12 @@ class RadiusGraph(Graph):
 
 class KNNGraph(Graph):
-    def __init__(self,
+    def __init__(
+            self,
             x,
             pos,
             k,
-            build_edge_attr=False,
-            undirected=False,
-            additional_params=None,
+            **kwargs
             ):
         x, pos, edge_index = Graph._check_input_consistency(x, pos)
         if isinstance(pos, (torch.Tensor, LabelTensor)):
@@ -218,9 +237,7 @@ class KNNGraph(Graph):
         else:
             edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
         super().__init__(x=x, pos=pos, edge_index=edge_index,
-                         build_edge_attr=build_edge_attr,
-                         undirected=undirected,
-                         additional_params=additional_params)
+                         **kwargs)
 
     @staticmethod
     def _knn_graph(points, k):
@@ -240,6 +257,7 @@ class KNNGraph(Graph):
         edge_index = torch.stack([row, col], dim=0)
         return edge_index
 
+
 class TemporalGraph(Graph):
     def __init__(
             self,
@@ -259,7 +277,8 @@ class TemporalGraph(Graph):
             edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
         additional_params = {'t': t}
         self._check_time_consistency(pos, t)
-        super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
+        super().__init__(x=x, pos=pos, edge_index=edge_index,
+                         edge_attr=edge_attr,
                          build_edge_attr=build_edge_attr,
                          undirected=undirected,
                          additional_params=additional_params)
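
For reference, a minimal usage sketch of the new custom_build_edge_attr hook, mirroring the test added below (the builder function, tensor shapes, and radius value here are illustrative only, not part of the commit):

import torch
from pina.graph import RadiusGraph

def build_edge_attr(x, pos, edge_index):
    # Receives (x, pos, edge_index) and must return one attribute row per edge;
    # this example uses only pos and edge_index, concatenating the coordinates
    # of the two endpoints of each edge.
    return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)

x = torch.rand(3, 10, 2)    # 3 graphs, 10 nodes each, 2 features per node
pos = torch.rand(3, 10, 2)  # matching 2D coordinates

# build_edge_attr=True enables attribute construction; the callable passed as
# custom_build_edge_attr replaces the default distance-based _build_edge_attr.
graph = RadiusGraph(x=x, pos=pos, r=0.3, build_edge_attr=True,
                    custom_build_edge_attr=build_edge_attr)

for d in graph.data:          # one torch_geometric Data object per graph
    print(d.edge_attr.shape)  # (num_edges, 4)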

View File

@@ -7,8 +7,10 @@ from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
 @pytest.mark.parametrize(
     "x, pos",
     [
-        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
-        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
+        ([torch.rand(10, 2) for _ in range(3)],
+         [torch.rand(10, 3) for _ in range(3)]),
+        ([torch.rand(10, 2) for _ in range(3)],
+         [torch.rand(10, 3) for _ in range(3)]),
         (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
         (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
     ]
@@ -112,6 +114,7 @@ def test_build_single_graph_single_val(pos):
     assert all(d.edge_attr is not None for d in data)
     assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
 
+
 def test_additional_parameters_1():
     x = torch.rand(3, 10, 2)
     pos = torch.rand(3, 10, 2)
@@ -124,6 +127,7 @@ def test_additional_parameters_1():
     assert all(hasattr(d, 'y') for d in data)
     assert all(d_.y == 1 for d_ in data)
 
+
 @pytest.mark.parametrize(
     "additional_parameters",
     [
@@ -142,6 +146,7 @@ def test_additional_parameters_2(additional_parameters):
     assert all(hasattr(d, 'y') for d in data)
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
 
+
 def test_temporal_graph():
     x = torch.rand(3, 10, 2)
     pos = torch.rand(3, 10, 2)
@@ -152,3 +157,21 @@ def test_temporal_graph():
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
     assert all(hasattr(d, 't') for d in data)
     assert all(d_.t == t_ for (d_, t_) in zip(data, t))
+
+
+def test_custom_build_edge_attr_func():
+    x = torch.rand(3, 10, 2)
+    pos = torch.rand(3, 10, 2)
+
+    def build_edge_attr(x, pos, edge_index):
+        return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
+
+    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
+                        custom_build_edge_attr=build_edge_attr)
+    assert len(graph.data) == 3
+    data = graph.data
+    assert all(hasattr(d, 'edge_attr') for d in data)
+    assert all(d.edge_attr.shape[1] == 4 for d in data)
+    assert all(torch.isclose(d.edge_attr,
+                             build_edge_attr(d.x, d.pos, d.edge_index)).all()
+               for d in data)