Refactor Graph class to support custom edge attribute logic
This commit is contained in:
committed by
Nicola Demo
parent
78b276d995
commit
bbdd5d4bf1
173
pina/graph.py
173
pina/graph.py
@@ -5,6 +5,7 @@ import torch
|
|||||||
from . import LabelTensor
|
from . import LabelTensor
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.utils import to_undirected
|
from torch_geometric.utils import to_undirected
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
class Graph:
|
class Graph:
|
||||||
@@ -12,14 +13,17 @@ class Graph:
|
|||||||
Class for the graph construction.
|
Class for the graph construction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
x,
|
self,
|
||||||
pos,
|
x,
|
||||||
edge_index,
|
pos,
|
||||||
edge_attr=None,
|
edge_index,
|
||||||
build_edge_attr=False,
|
edge_attr=None,
|
||||||
undirected=False,
|
build_edge_attr=False,
|
||||||
additional_params=None):
|
undirected=False,
|
||||||
|
custom_build_edge_attr=None,
|
||||||
|
additional_params=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Constructor for the Graph class.
|
Constructor for the Graph class.
|
||||||
:param x: The node features.
|
:param x: The node features.
|
||||||
@@ -34,45 +38,23 @@ class Graph:
|
|||||||
:type build_edge_attr: bool
|
:type build_edge_attr: bool
|
||||||
:param undirected: Whether to build an undirected graph.
|
:param undirected: Whether to build an undirected graph.
|
||||||
:type undirected: bool
|
:type undirected: bool
|
||||||
|
:param custom_build_edge_attr: Custom function to build the edge
|
||||||
|
attributes.
|
||||||
|
:type custom_build_edge_attr: function
|
||||||
:param additional_params: Additional parameters.
|
:param additional_params: Additional parameters.
|
||||||
:type additional_params: dict
|
:type additional_params: dict
|
||||||
"""
|
"""
|
||||||
self.data = []
|
self.data = []
|
||||||
x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)
|
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||||
|
|
||||||
# Check input dimension consistency and store the number of graphs
|
# Check input dimension consistency and store the number of graphs
|
||||||
data_len = self._check_len_consistency(x, pos)
|
data_len = self._check_len_consistency(x, pos)
|
||||||
|
if inspect.isfunction(custom_build_edge_attr):
|
||||||
|
self._build_edge_attr = custom_build_edge_attr
|
||||||
|
|
||||||
# Initialize additional_parameters (if present)
|
# Check consistency and initialize additional_parameters (if present)
|
||||||
if additional_params is not None:
|
additional_params = self._check_additional_params(additional_params,
|
||||||
if not isinstance(additional_params, dict):
|
data_len)
|
||||||
raise TypeError("additional_params must be a dictionary.")
|
|
||||||
for param, val in additional_params.items():
|
|
||||||
# Check if the values are tensors or lists of tensors
|
|
||||||
if isinstance(val, torch.Tensor):
|
|
||||||
# If the tensor is 3D, we split it into a list of 2D tensors
|
|
||||||
# In this case there must be a additional parameter for each
|
|
||||||
# node
|
|
||||||
if val.ndim == 3:
|
|
||||||
additional_params[param] = [val[i] for i in
|
|
||||||
range(val.shape[0])]
|
|
||||||
# If the tensor is 2D, we replicate it for each node
|
|
||||||
elif val.ndim == 2:
|
|
||||||
additional_params[param] = [val] * data_len
|
|
||||||
# If the tensor is 1D, each graph has a scalar values as
|
|
||||||
# additional parameter
|
|
||||||
if val.ndim == 1:
|
|
||||||
if len(val) == data_len:
|
|
||||||
additional_params[param] = [val[i] for i in
|
|
||||||
range(len(val))]
|
|
||||||
else:
|
|
||||||
additional_params[param] = [val for _ in
|
|
||||||
range(data_len)]
|
|
||||||
elif not isinstance(val, list):
|
|
||||||
raise TypeError("additional_params values must be tensors "
|
|
||||||
"or lists of tensors.")
|
|
||||||
else:
|
|
||||||
additional_params = {}
|
|
||||||
|
|
||||||
# Make the graphs undirected
|
# Make the graphs undirected
|
||||||
if undirected:
|
if undirected:
|
||||||
@@ -81,27 +63,17 @@ class Graph:
|
|||||||
else:
|
else:
|
||||||
edge_index = to_undirected(edge_index)
|
edge_index = to_undirected(edge_index)
|
||||||
|
|
||||||
if build_edge_attr:
|
|
||||||
if edge_attr is not None:
|
|
||||||
warning("Edge attributes are provided, build_edge_attr is set "
|
|
||||||
"to True. The provided edge attributes will be ignored.")
|
|
||||||
edge_attr = self._build_edge_attr(pos, edge_index)
|
|
||||||
|
|
||||||
# Prepare internal lists to create a graph list (same positions but
|
# Prepare internal lists to create a graph list (same positions but
|
||||||
# different node features)
|
# different node features)
|
||||||
if isinstance(x, list) and isinstance(pos,
|
if isinstance(x, list) and isinstance(pos,
|
||||||
(torch.Tensor, LabelTensor)):
|
(torch.Tensor, LabelTensor)):
|
||||||
# Replicate the positions, edge_index and edge_attr
|
# Replicate the positions, edge_index and edge_attr
|
||||||
pos, edge_index = [pos] * data_len, [edge_index] * data_len
|
pos, edge_index = [pos] * data_len, [edge_index] * data_len
|
||||||
if edge_attr is not None:
|
|
||||||
edge_attr = [edge_attr] * data_len
|
|
||||||
# Prepare internal lists to create a list containing a single graph
|
# Prepare internal lists to create a list containing a single graph
|
||||||
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
|
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
|
||||||
torch.Tensor, LabelTensor)):
|
torch.Tensor, LabelTensor)):
|
||||||
# Encapsulate the input tensors into lists
|
# Encapsulate the input tensors into lists
|
||||||
x, pos, edge_index = [x], [pos], [edge_index]
|
x, pos, edge_index = [x], [pos], [edge_index]
|
||||||
if isinstance(edge_attr, torch.Tensor):
|
|
||||||
edge_attr = [edge_attr]
|
|
||||||
# Prepare internal lists to create a list of graphs (same node features
|
# Prepare internal lists to create a list of graphs (same node features
|
||||||
# but different positions)
|
# but different positions)
|
||||||
elif (isinstance(x, (torch.Tensor, LabelTensor))
|
elif (isinstance(x, (torch.Tensor, LabelTensor))
|
||||||
@@ -111,6 +83,10 @@ class Graph:
|
|||||||
elif not isinstance(x, list) and not isinstance(pos, list):
|
elif not isinstance(x, list) and not isinstance(pos, list):
|
||||||
raise TypeError("x and pos must be lists or tensors.")
|
raise TypeError("x and pos must be lists or tensors.")
|
||||||
|
|
||||||
|
# Build the edge attributes
|
||||||
|
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
|
||||||
|
data_len, edge_index, pos, x)
|
||||||
|
|
||||||
# Perform the graph construction
|
# Perform the graph construction
|
||||||
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
||||||
|
|
||||||
@@ -130,12 +106,8 @@ class Graph:
|
|||||||
**add_params_local))
|
**add_params_local))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_edge_attr(pos, edge_index):
|
def _build_edge_attr(x, pos, edge_index):
|
||||||
if isinstance(pos, torch.Tensor):
|
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
|
||||||
pos = [pos]
|
|
||||||
edge_index = [edge_index]
|
|
||||||
distance = [pos_[edge_index_[0]] - pos_[edge_index_[1]] ** 2 for
|
|
||||||
pos_, edge_index_ in zip(pos, edge_index)]
|
|
||||||
return distance
|
return distance
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -166,15 +138,65 @@ class Graph:
|
|||||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||||
return x, pos, edge_index
|
return x, pos, edge_index
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_additional_params(additional_params, data_len):
|
||||||
|
if additional_params is not None:
|
||||||
|
if not isinstance(additional_params, dict):
|
||||||
|
raise TypeError("additional_params must be a dictionary.")
|
||||||
|
for param, val in additional_params.items():
|
||||||
|
# Check if the values are tensors or lists of tensors
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
# If the tensor is 3D, we split it into a list of 2D tensors
|
||||||
|
# In this case there must be a additional parameter for each
|
||||||
|
# node
|
||||||
|
if val.ndim == 3:
|
||||||
|
additional_params[param] = [val[i] for i in
|
||||||
|
range(val.shape[0])]
|
||||||
|
# If the tensor is 2D, we replicate it for each node
|
||||||
|
elif val.ndim == 2:
|
||||||
|
additional_params[param] = [val] * data_len
|
||||||
|
# If the tensor is 1D, each graph has a scalar values as
|
||||||
|
# additional parameter
|
||||||
|
if val.ndim == 1:
|
||||||
|
if len(val) == data_len:
|
||||||
|
additional_params[param] = [val[i] for i in
|
||||||
|
range(len(val))]
|
||||||
|
else:
|
||||||
|
additional_params[param] = [val for _ in
|
||||||
|
range(data_len)]
|
||||||
|
elif not isinstance(val, list):
|
||||||
|
raise TypeError("additional_params values must be tensors "
|
||||||
|
"or lists of tensors.")
|
||||||
|
else:
|
||||||
|
additional_params = {}
|
||||||
|
return additional_params
|
||||||
|
|
||||||
|
def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
|
||||||
|
edge_index, pos, x):
|
||||||
|
# Check if edge_attr is consistent with x and pos
|
||||||
|
if edge_attr is not None:
|
||||||
|
if build_edge_attr is True:
|
||||||
|
warning("edge_attr is not None. build_edge_attr will not be "
|
||||||
|
"considered.")
|
||||||
|
if isinstance(edge_attr, list):
|
||||||
|
if len(edge_attr) != data_len:
|
||||||
|
raise ValueError("edge_attr must have the same length as x "
|
||||||
|
"and pos.")
|
||||||
|
return [edge_attr] * data_len
|
||||||
|
|
||||||
|
if build_edge_attr:
|
||||||
|
return [self._build_edge_attr(x,pos_, edge_index_) for
|
||||||
|
pos_, edge_index_ in zip(pos, edge_index)]
|
||||||
|
|
||||||
|
|
||||||
class RadiusGraph(Graph):
|
class RadiusGraph(Graph):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
x,
|
self,
|
||||||
pos,
|
x,
|
||||||
r,
|
pos,
|
||||||
build_edge_attr=False,
|
r,
|
||||||
undirected=False,
|
**kwargs
|
||||||
additional_params=None, ):
|
):
|
||||||
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
||||||
|
|
||||||
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
||||||
@@ -183,9 +205,7 @@ class RadiusGraph(Graph):
|
|||||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||||
|
|
||||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||||
build_edge_attr=build_edge_attr,
|
**kwargs)
|
||||||
undirected=undirected,
|
|
||||||
additional_params=additional_params)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _radius_graph(points, r):
|
def _radius_graph(points, r):
|
||||||
@@ -204,23 +224,20 @@ class RadiusGraph(Graph):
|
|||||||
|
|
||||||
|
|
||||||
class KNNGraph(Graph):
|
class KNNGraph(Graph):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
x,
|
self,
|
||||||
pos,
|
x,
|
||||||
k,
|
pos,
|
||||||
build_edge_attr=False,
|
k,
|
||||||
undirected=False,
|
**kwargs
|
||||||
additional_params=None,
|
):
|
||||||
):
|
|
||||||
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
||||||
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
||||||
edge_index = KNNGraph._knn_graph(pos, k)
|
edge_index = KNNGraph._knn_graph(pos, k)
|
||||||
else:
|
else:
|
||||||
edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
|
edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
|
||||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||||
build_edge_attr=build_edge_attr,
|
**kwargs)
|
||||||
undirected=undirected,
|
|
||||||
additional_params=additional_params)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _knn_graph(points, k):
|
def _knn_graph(points, k):
|
||||||
@@ -240,6 +257,7 @@ class KNNGraph(Graph):
|
|||||||
edge_index = torch.stack([row, col], dim=0)
|
edge_index = torch.stack([row, col], dim=0)
|
||||||
return edge_index
|
return edge_index
|
||||||
|
|
||||||
|
|
||||||
class TemporalGraph(Graph):
|
class TemporalGraph(Graph):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -259,7 +277,8 @@ class TemporalGraph(Graph):
|
|||||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||||
additional_params = {'t': t}
|
additional_params = {'t': t}
|
||||||
self._check_time_consistency(pos, t)
|
self._check_time_consistency(pos, t)
|
||||||
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
|
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||||
|
edge_attr=edge_attr,
|
||||||
build_edge_attr=build_edge_attr,
|
build_edge_attr=build_edge_attr,
|
||||||
undirected=undirected,
|
undirected=undirected,
|
||||||
additional_params=additional_params)
|
additional_params=additional_params)
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"x, pos",
|
"x, pos",
|
||||||
[
|
[
|
||||||
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
|
([torch.rand(10, 2) for _ in range(3)],
|
||||||
([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
|
[torch.rand(10, 3) for _ in range(3)]),
|
||||||
(torch.rand(3,10,2), torch.rand(3,10,3)),
|
([torch.rand(10, 2) for _ in range(3)],
|
||||||
(torch.rand(3,10,2), torch.rand(3,10,3)),
|
[torch.rand(10, 3) for _ in range(3)]),
|
||||||
|
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
|
||||||
|
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_build_multiple_graph_multiple_val(x, pos):
|
def test_build_multiple_graph_multiple_val(x, pos):
|
||||||
@@ -28,7 +30,7 @@ def test_build_multiple_graph_multiple_val(x, pos):
|
|||||||
assert all(d.edge_attr is not None for d in data)
|
assert all(d.edge_attr is not None for d in data)
|
||||||
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
|
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
|
||||||
|
|
||||||
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k = 3)
|
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
|
||||||
data = graph.data
|
data = graph.data
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
|
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
|
||||||
@@ -112,36 +114,39 @@ def test_build_single_graph_single_val(pos):
|
|||||||
assert all(d.edge_attr is not None for d in data)
|
assert all(d.edge_attr is not None for d in data)
|
||||||
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
|
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
|
||||||
|
|
||||||
|
|
||||||
def test_additional_parameters_1():
|
def test_additional_parameters_1():
|
||||||
x = torch.rand(3, 10, 2)
|
x = torch.rand(3, 10, 2)
|
||||||
pos = torch.rand(3, 10, 2)
|
pos = torch.rand(3, 10, 2)
|
||||||
additional_parameters = {'y': torch.ones(3)}
|
additional_parameters = {'y': torch.ones(3)}
|
||||||
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
|
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
|
||||||
additional_params=additional_parameters)
|
additional_params=additional_parameters)
|
||||||
assert len(graph.data) == 3
|
assert len(graph.data) == 3
|
||||||
data = graph.data
|
data = graph.data
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
assert all(hasattr(d, 'y') for d in data)
|
assert all(hasattr(d, 'y') for d in data)
|
||||||
assert all(d_.y == 1 for d_ in data)
|
assert all(d_.y == 1 for d_ in data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"additional_parameters",
|
"additional_parameters",
|
||||||
[
|
[
|
||||||
({'y': torch.rand(3,10,1)}),
|
({'y': torch.rand(3, 10, 1)}),
|
||||||
({'y': [torch.rand(10,1) for _ in range(3)]}),
|
({'y': [torch.rand(10, 1) for _ in range(3)]}),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_additional_parameters_2(additional_parameters):
|
def test_additional_parameters_2(additional_parameters):
|
||||||
x = torch.rand(3, 10, 2)
|
x = torch.rand(3, 10, 2)
|
||||||
pos = torch.rand(3, 10, 2)
|
pos = torch.rand(3, 10, 2)
|
||||||
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
|
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
|
||||||
additional_params=additional_parameters)
|
additional_params=additional_parameters)
|
||||||
assert len(graph.data) == 3
|
assert len(graph.data) == 3
|
||||||
data = graph.data
|
data = graph.data
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
assert all(hasattr(d, 'y') for d in data)
|
assert all(hasattr(d, 'y') for d in data)
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
|
|
||||||
|
|
||||||
def test_temporal_graph():
|
def test_temporal_graph():
|
||||||
x = torch.rand(3, 10, 2)
|
x = torch.rand(3, 10, 2)
|
||||||
pos = torch.rand(3, 10, 2)
|
pos = torch.rand(3, 10, 2)
|
||||||
@@ -152,3 +157,21 @@ def test_temporal_graph():
|
|||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
assert all(hasattr(d, 't') for d in data)
|
assert all(hasattr(d, 't') for d in data)
|
||||||
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_build_edge_attr_func():
|
||||||
|
x = torch.rand(3, 10, 2)
|
||||||
|
pos = torch.rand(3, 10, 2)
|
||||||
|
|
||||||
|
def build_edge_attr(x, pos, edge_index):
|
||||||
|
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
|
||||||
|
|
||||||
|
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
|
||||||
|
custom_build_edge_attr=build_edge_attr)
|
||||||
|
assert len(graph.data) == 3
|
||||||
|
data = graph.data
|
||||||
|
assert all(hasattr(d, 'edge_attr') for d in data)
|
||||||
|
assert all(d.edge_attr.shape[1] == 4 for d in data)
|
||||||
|
assert all(torch.isclose(d.edge_attr,
|
||||||
|
build_edge_attr(d.x, d.pos, d.edge_index)).all()
|
||||||
|
for d in data)
|
||||||
|
|||||||
Reference in New Issue
Block a user