Refactor Graph class to support custom edge attribute logic

This commit is contained in:
FilippoOlivo
2025-02-05 10:59:48 +01:00
committed by Nicola Demo
parent 78b276d995
commit bbdd5d4bf1
2 changed files with 128 additions and 86 deletions

View File

@@ -5,6 +5,7 @@ import torch
from . import LabelTensor
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
import inspect
class Graph:
@@ -12,14 +13,17 @@ class Graph:
Class for the graph construction.
"""
def __init__(self,
x,
pos,
edge_index,
edge_attr=None,
build_edge_attr=False,
undirected=False,
additional_params=None):
def __init__(
self,
x,
pos,
edge_index,
edge_attr=None,
build_edge_attr=False,
undirected=False,
custom_build_edge_attr=None,
additional_params=None
):
"""
Constructor for the Graph class.
:param x: The node features.
@@ -34,45 +38,23 @@ class Graph:
:type build_edge_attr: bool
:param undirected: Whether to build an undirected graph.
:type undirected: bool
:param custom_build_edge_attr: Custom function to build the edge
attributes.
:type custom_build_edge_attr: function
:param additional_params: Additional parameters.
:type additional_params: dict
"""
self.data = []
x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
# Check input dimension consistency and store the number of graphs
data_len = self._check_len_consistency(x, pos)
if inspect.isfunction(custom_build_edge_attr):
self._build_edge_attr = custom_build_edge_attr
# Initialize additional_parameters (if present)
if additional_params is not None:
if not isinstance(additional_params, dict):
raise TypeError("additional_params must be a dictionary.")
for param, val in additional_params.items():
# Check if the values are tensors or lists of tensors
if isinstance(val, torch.Tensor):
# If the tensor is 3D, we split it into a list of 2D tensors
# In this case there must be a additional parameter for each
# node
if val.ndim == 3:
additional_params[param] = [val[i] for i in
range(val.shape[0])]
# If the tensor is 2D, we replicate it for each node
elif val.ndim == 2:
additional_params[param] = [val] * data_len
# If the tensor is 1D, each graph has a scalar values as
# additional parameter
if val.ndim == 1:
if len(val) == data_len:
additional_params[param] = [val[i] for i in
range(len(val))]
else:
additional_params[param] = [val for _ in
range(data_len)]
elif not isinstance(val, list):
raise TypeError("additional_params values must be tensors "
"or lists of tensors.")
else:
additional_params = {}
# Check consistency and initialize additional_parameters (if present)
additional_params = self._check_additional_params(additional_params,
data_len)
# Make the graphs undirected
if undirected:
@@ -81,27 +63,17 @@ class Graph:
else:
edge_index = to_undirected(edge_index)
if build_edge_attr:
if edge_attr is not None:
warning("Edge attributes are provided, build_edge_attr is set "
"to True. The provided edge attributes will be ignored.")
edge_attr = self._build_edge_attr(pos, edge_index)
# Prepare internal lists to create a graph list (same positions but
# different node features)
if isinstance(x, list) and isinstance(pos,
(torch.Tensor, LabelTensor)):
# Replicate the positions, edge_index and edge_attr
pos, edge_index = [pos] * data_len, [edge_index] * data_len
if edge_attr is not None:
edge_attr = [edge_attr] * data_len
# Prepare internal lists to create a list containing a single graph
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
torch.Tensor, LabelTensor)):
# Encapsulate the input tensors into lists
x, pos, edge_index = [x], [pos], [edge_index]
if isinstance(edge_attr, torch.Tensor):
edge_attr = [edge_attr]
# Prepare internal lists to create a list of graphs (same node features
# but different positions)
elif (isinstance(x, (torch.Tensor, LabelTensor))
@@ -111,6 +83,10 @@ class Graph:
elif not isinstance(x, list) and not isinstance(pos, list):
raise TypeError("x and pos must be lists or tensors.")
# Build the edge attributes
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
data_len, edge_index, pos, x)
# Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
@@ -130,12 +106,8 @@ class Graph:
**add_params_local))
@staticmethod
def _build_edge_attr(pos, edge_index):
if isinstance(pos, torch.Tensor):
pos = [pos]
edge_index = [edge_index]
distance = [pos_[edge_index_[0]] - pos_[edge_index_[1]] ** 2 for
pos_, edge_index_ in zip(pos, edge_index)]
def _build_edge_attr(x, pos, edge_index):
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
return distance
@staticmethod
@@ -166,15 +138,65 @@ class Graph:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
return x, pos, edge_index
@staticmethod
def _check_additional_params(additional_params, data_len):
if additional_params is not None:
if not isinstance(additional_params, dict):
raise TypeError("additional_params must be a dictionary.")
for param, val in additional_params.items():
# Check if the values are tensors or lists of tensors
if isinstance(val, torch.Tensor):
# If the tensor is 3D, we split it into a list of 2D tensors
# In this case there must be a additional parameter for each
# node
if val.ndim == 3:
additional_params[param] = [val[i] for i in
range(val.shape[0])]
# If the tensor is 2D, we replicate it for each node
elif val.ndim == 2:
additional_params[param] = [val] * data_len
# If the tensor is 1D, each graph has a scalar values as
# additional parameter
if val.ndim == 1:
if len(val) == data_len:
additional_params[param] = [val[i] for i in
range(len(val))]
else:
additional_params[param] = [val for _ in
range(data_len)]
elif not isinstance(val, list):
raise TypeError("additional_params values must be tensors "
"or lists of tensors.")
else:
additional_params = {}
return additional_params
def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
edge_index, pos, x):
# Check if edge_attr is consistent with x and pos
if edge_attr is not None:
if build_edge_attr is True:
warning("edge_attr is not None. build_edge_attr will not be "
"considered.")
if isinstance(edge_attr, list):
if len(edge_attr) != data_len:
raise ValueError("edge_attr must have the same length as x "
"and pos.")
return [edge_attr] * data_len
if build_edge_attr:
return [self._build_edge_attr(x,pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)]
class RadiusGraph(Graph):
def __init__(self,
x,
pos,
r,
build_edge_attr=False,
undirected=False,
additional_params=None, ):
def __init__(
self,
x,
pos,
r,
**kwargs
):
x, pos, edge_index = Graph._check_input_consistency(x, pos)
if isinstance(pos, (torch.Tensor, LabelTensor)):
@@ -183,9 +205,7 @@ class RadiusGraph(Graph):
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
super().__init__(x=x, pos=pos, edge_index=edge_index,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)
**kwargs)
@staticmethod
def _radius_graph(points, r):
@@ -204,23 +224,20 @@ class RadiusGraph(Graph):
class KNNGraph(Graph):
def __init__(self,
x,
pos,
k,
build_edge_attr=False,
undirected=False,
additional_params=None,
):
def __init__(
self,
x,
pos,
k,
**kwargs
):
x, pos, edge_index = Graph._check_input_consistency(x, pos)
if isinstance(pos, (torch.Tensor, LabelTensor)):
edge_index = KNNGraph._knn_graph(pos, k)
else:
edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
super().__init__(x=x, pos=pos, edge_index=edge_index,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)
**kwargs)
@staticmethod
def _knn_graph(points, k):
@@ -240,6 +257,7 @@ class KNNGraph(Graph):
edge_index = torch.stack([row, col], dim=0)
return edge_index
class TemporalGraph(Graph):
def __init__(
self,
@@ -259,7 +277,8 @@ class TemporalGraph(Graph):
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t}
self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
super().__init__(x=x, pos=pos, edge_index=edge_index,
edge_attr=edge_attr,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)

View File

@@ -7,10 +7,12 @@ from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
@pytest.mark.parametrize(
"x, pos",
[
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3,10,2), torch.rand(3,10,3)),
(torch.rand(3,10,2), torch.rand(3,10,3)),
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
]
)
def test_build_multiple_graph_multiple_val(x, pos):
@@ -28,7 +30,7 @@ def test_build_multiple_graph_multiple_val(x, pos):
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k = 3)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
@@ -112,36 +114,39 @@ def test_build_single_graph_single_val(pos):
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
def test_additional_parameters_1():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
additional_parameters = {'y': torch.ones(3)}
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(d_.y == 1 for d_ in data)
@pytest.mark.parametrize(
"additional_parameters",
[
({'y': torch.rand(3,10,1)}),
({'y': [torch.rand(10,1) for _ in range(3)]}),
({'y': torch.rand(3, 10, 1)}),
({'y': [torch.rand(10, 1) for _ in range(3)]}),
]
)
def test_additional_parameters_2(additional_parameters):
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
def test_temporal_graph():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
@@ -152,3 +157,21 @@ def test_temporal_graph():
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 't') for d in data)
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
def build_edge_attr(x, pos, edge_index):
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
custom_build_edge_attr=build_edge_attr)
assert len(graph.data) == 3
data = graph.data
assert all(hasattr(d, 'edge_attr') for d in data)
assert all(d.edge_attr.shape[1] == 4 for d in data)
assert all(torch.isclose(d.edge_attr,
build_edge_attr(d.x, d.pos, d.edge_index)).all()
for d in data)