Refactor Graph class to support custom edge attribute logic

This commit is contained in:
FilippoOlivo
2025-02-05 10:59:48 +01:00
committed by Nicola Demo
parent 78b276d995
commit bbdd5d4bf1
2 changed files with 128 additions and 86 deletions

View File

@@ -5,6 +5,7 @@ import torch
from . import LabelTensor from . import LabelTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import to_undirected from torch_geometric.utils import to_undirected
import inspect
class Graph: class Graph:
@@ -12,14 +13,17 @@ class Graph:
Class for the graph construction. Class for the graph construction.
""" """
def __init__(self, def __init__(
x, self,
pos, x,
edge_index, pos,
edge_attr=None, edge_index,
build_edge_attr=False, edge_attr=None,
undirected=False, build_edge_attr=False,
additional_params=None): undirected=False,
custom_build_edge_attr=None,
additional_params=None
):
""" """
Constructor for the Graph class. Constructor for the Graph class.
:param x: The node features. :param x: The node features.
@@ -34,45 +38,23 @@ class Graph:
:type build_edge_attr: bool :type build_edge_attr: bool
:param undirected: Whether to build an undirected graph. :param undirected: Whether to build an undirected graph.
:type undirected: bool :type undirected: bool
:param custom_build_edge_attr: Custom function to build the edge
attributes.
:type custom_build_edge_attr: function
:param additional_params: Additional parameters. :param additional_params: Additional parameters.
:type additional_params: dict :type additional_params: dict
""" """
self.data = [] self.data = []
x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index) x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
# Check input dimension consistency and store the number of graphs # Check input dimension consistency and store the number of graphs
data_len = self._check_len_consistency(x, pos) data_len = self._check_len_consistency(x, pos)
if inspect.isfunction(custom_build_edge_attr):
self._build_edge_attr = custom_build_edge_attr
# Initialize additional_parameters (if present) # Check consistency and initialize additional_parameters (if present)
if additional_params is not None: additional_params = self._check_additional_params(additional_params,
if not isinstance(additional_params, dict): data_len)
raise TypeError("additional_params must be a dictionary.")
for param, val in additional_params.items():
# Check if the values are tensors or lists of tensors
if isinstance(val, torch.Tensor):
# If the tensor is 3D, we split it into a list of 2D tensors
# In this case there must be a additional parameter for each
# node
if val.ndim == 3:
additional_params[param] = [val[i] for i in
range(val.shape[0])]
# If the tensor is 2D, we replicate it for each node
elif val.ndim == 2:
additional_params[param] = [val] * data_len
# If the tensor is 1D, each graph has a scalar values as
# additional parameter
if val.ndim == 1:
if len(val) == data_len:
additional_params[param] = [val[i] for i in
range(len(val))]
else:
additional_params[param] = [val for _ in
range(data_len)]
elif not isinstance(val, list):
raise TypeError("additional_params values must be tensors "
"or lists of tensors.")
else:
additional_params = {}
# Make the graphs undirected # Make the graphs undirected
if undirected: if undirected:
@@ -81,27 +63,17 @@ class Graph:
else: else:
edge_index = to_undirected(edge_index) edge_index = to_undirected(edge_index)
if build_edge_attr:
if edge_attr is not None:
warning("Edge attributes are provided, build_edge_attr is set "
"to True. The provided edge attributes will be ignored.")
edge_attr = self._build_edge_attr(pos, edge_index)
# Prepare internal lists to create a graph list (same positions but # Prepare internal lists to create a graph list (same positions but
# different node features) # different node features)
if isinstance(x, list) and isinstance(pos, if isinstance(x, list) and isinstance(pos,
(torch.Tensor, LabelTensor)): (torch.Tensor, LabelTensor)):
# Replicate the positions, edge_index and edge_attr # Replicate the positions, edge_index and edge_attr
pos, edge_index = [pos] * data_len, [edge_index] * data_len pos, edge_index = [pos] * data_len, [edge_index] * data_len
if edge_attr is not None:
edge_attr = [edge_attr] * data_len
# Prepare internal lists to create a list containing a single graph # Prepare internal lists to create a list containing a single graph
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, ( elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
torch.Tensor, LabelTensor)): torch.Tensor, LabelTensor)):
# Encapsulate the input tensors into lists # Encapsulate the input tensors into lists
x, pos, edge_index = [x], [pos], [edge_index] x, pos, edge_index = [x], [pos], [edge_index]
if isinstance(edge_attr, torch.Tensor):
edge_attr = [edge_attr]
# Prepare internal lists to create a list of graphs (same node features # Prepare internal lists to create a list of graphs (same node features
# but different positions) # but different positions)
elif (isinstance(x, (torch.Tensor, LabelTensor)) elif (isinstance(x, (torch.Tensor, LabelTensor))
@@ -111,6 +83,10 @@ class Graph:
elif not isinstance(x, list) and not isinstance(pos, list): elif not isinstance(x, list) and not isinstance(pos, list):
raise TypeError("x and pos must be lists or tensors.") raise TypeError("x and pos must be lists or tensors.")
# Build the edge attributes
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
data_len, edge_index, pos, x)
# Perform the graph construction # Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
@@ -130,12 +106,8 @@ class Graph:
**add_params_local)) **add_params_local))
@staticmethod @staticmethod
def _build_edge_attr(pos, edge_index): def _build_edge_attr(x, pos, edge_index):
if isinstance(pos, torch.Tensor): distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
pos = [pos]
edge_index = [edge_index]
distance = [pos_[edge_index_[0]] - pos_[edge_index_[1]] ** 2 for
pos_, edge_index_ in zip(pos, edge_index)]
return distance return distance
@staticmethod @staticmethod
@@ -166,15 +138,65 @@ class Graph:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])] edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
return x, pos, edge_index return x, pos, edge_index
@staticmethod
def _check_additional_params(additional_params, data_len):
if additional_params is not None:
if not isinstance(additional_params, dict):
raise TypeError("additional_params must be a dictionary.")
for param, val in additional_params.items():
# Check if the values are tensors or lists of tensors
if isinstance(val, torch.Tensor):
# If the tensor is 3D, we split it into a list of 2D tensors
# In this case there must be a additional parameter for each
# node
if val.ndim == 3:
additional_params[param] = [val[i] for i in
range(val.shape[0])]
# If the tensor is 2D, we replicate it for each node
elif val.ndim == 2:
additional_params[param] = [val] * data_len
# If the tensor is 1D, each graph has a scalar values as
# additional parameter
if val.ndim == 1:
if len(val) == data_len:
additional_params[param] = [val[i] for i in
range(len(val))]
else:
additional_params[param] = [val for _ in
range(data_len)]
elif not isinstance(val, list):
raise TypeError("additional_params values must be tensors "
"or lists of tensors.")
else:
additional_params = {}
return additional_params
def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
edge_index, pos, x):
# Check if edge_attr is consistent with x and pos
if edge_attr is not None:
if build_edge_attr is True:
warning("edge_attr is not None. build_edge_attr will not be "
"considered.")
if isinstance(edge_attr, list):
if len(edge_attr) != data_len:
raise ValueError("edge_attr must have the same length as x "
"and pos.")
return [edge_attr] * data_len
if build_edge_attr:
return [self._build_edge_attr(x,pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)]
class RadiusGraph(Graph): class RadiusGraph(Graph):
def __init__(self, def __init__(
x, self,
pos, x,
r, pos,
build_edge_attr=False, r,
undirected=False, **kwargs
additional_params=None, ): ):
x, pos, edge_index = Graph._check_input_consistency(x, pos) x, pos, edge_index = Graph._check_input_consistency(x, pos)
if isinstance(pos, (torch.Tensor, LabelTensor)): if isinstance(pos, (torch.Tensor, LabelTensor)):
@@ -183,9 +205,7 @@ class RadiusGraph(Graph):
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
super().__init__(x=x, pos=pos, edge_index=edge_index, super().__init__(x=x, pos=pos, edge_index=edge_index,
build_edge_attr=build_edge_attr, **kwargs)
undirected=undirected,
additional_params=additional_params)
@staticmethod @staticmethod
def _radius_graph(points, r): def _radius_graph(points, r):
@@ -204,23 +224,20 @@ class RadiusGraph(Graph):
class KNNGraph(Graph): class KNNGraph(Graph):
def __init__(self, def __init__(
x, self,
pos, x,
k, pos,
build_edge_attr=False, k,
undirected=False, **kwargs
additional_params=None, ):
):
x, pos, edge_index = Graph._check_input_consistency(x, pos) x, pos, edge_index = Graph._check_input_consistency(x, pos)
if isinstance(pos, (torch.Tensor, LabelTensor)): if isinstance(pos, (torch.Tensor, LabelTensor)):
edge_index = KNNGraph._knn_graph(pos, k) edge_index = KNNGraph._knn_graph(pos, k)
else: else:
edge_index = [KNNGraph._knn_graph(p, k) for p in pos] edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
super().__init__(x=x, pos=pos, edge_index=edge_index, super().__init__(x=x, pos=pos, edge_index=edge_index,
build_edge_attr=build_edge_attr, **kwargs)
undirected=undirected,
additional_params=additional_params)
@staticmethod @staticmethod
def _knn_graph(points, k): def _knn_graph(points, k):
@@ -240,6 +257,7 @@ class KNNGraph(Graph):
edge_index = torch.stack([row, col], dim=0) edge_index = torch.stack([row, col], dim=0)
return edge_index return edge_index
class TemporalGraph(Graph): class TemporalGraph(Graph):
def __init__( def __init__(
self, self,
@@ -259,7 +277,8 @@ class TemporalGraph(Graph):
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t} additional_params = {'t': t}
self._check_time_consistency(pos, t) self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, super().__init__(x=x, pos=pos, edge_index=edge_index,
edge_attr=edge_attr,
build_edge_attr=build_edge_attr, build_edge_attr=build_edge_attr,
undirected=undirected, undirected=undirected,
additional_params=additional_params) additional_params=additional_params)

View File

@@ -7,10 +7,12 @@ from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, pos", "x, pos",
[ [
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]), ([torch.rand(10, 2) for _ in range(3)],
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]), [torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3,10,2), torch.rand(3,10,3)), ([torch.rand(10, 2) for _ in range(3)],
(torch.rand(3,10,2), torch.rand(3,10,3)), [torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
] ]
) )
def test_build_multiple_graph_multiple_val(x, pos): def test_build_multiple_graph_multiple_val(x, pos):
@@ -28,7 +30,7 @@ def test_build_multiple_graph_multiple_val(x, pos):
assert all(d.edge_attr is not None for d in data) assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k = 3) graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
@@ -112,36 +114,39 @@ def test_build_single_graph_single_val(pos):
assert all(d.edge_attr is not None for d in data) assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
def test_additional_parameters_1(): def test_additional_parameters_1():
x = torch.rand(3, 10, 2) x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2) pos = torch.rand(3, 10, 2)
additional_parameters = {'y': torch.ones(3)} additional_parameters = {'y': torch.ones(3)}
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters) additional_params=additional_parameters)
assert len(graph.data) == 3 assert len(graph.data) == 3
data = graph.data data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data) assert all(hasattr(d, 'y') for d in data)
assert all(d_.y == 1 for d_ in data) assert all(d_.y == 1 for d_ in data)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"additional_parameters", "additional_parameters",
[ [
({'y': torch.rand(3,10,1)}), ({'y': torch.rand(3, 10, 1)}),
({'y': [torch.rand(10,1) for _ in range(3)]}), ({'y': [torch.rand(10, 1) for _ in range(3)]}),
] ]
) )
def test_additional_parameters_2(additional_parameters): def test_additional_parameters_2(additional_parameters):
x = torch.rand(3, 10, 2) x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2) pos = torch.rand(3, 10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters) additional_params=additional_parameters)
assert len(graph.data) == 3 assert len(graph.data) == 3
data = graph.data data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data) assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
def test_temporal_graph(): def test_temporal_graph():
x = torch.rand(3, 10, 2) x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2) pos = torch.rand(3, 10, 2)
@@ -152,3 +157,21 @@ def test_temporal_graph():
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 't') for d in data) assert all(hasattr(d, 't') for d in data)
assert all(d_.t == t_ for (d_, t_) in zip(data, t)) assert all(d_.t == t_ for (d_, t_) in zip(data, t))
def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
def build_edge_attr(x, pos, edge_index):
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
custom_build_edge_attr=build_edge_attr)
assert len(graph.data) == 3
data = graph.data
assert all(hasattr(d, 'edge_attr') for d in data)
assert all(d.edge_attr.shape[1] == 4 for d in data)
assert all(torch.isclose(d.edge_attr,
build_edge_attr(d.x, d.pos, d.edge_index)).all()
for d in data)