From bbdd5d4bf1456043736a2f44595378f84ea0bc13 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Wed, 5 Feb 2025 10:59:48 +0100
Subject: [PATCH] Refactor Graph class to support custom edge attribute logic

---
 pina/graph.py       | 173 ++++++++++++++++++++++++--------------------
 tests/test_graph.py |  41 ++++++++---
 2 files changed, 128 insertions(+), 86 deletions(-)

diff --git a/pina/graph.py b/pina/graph.py
index 7365bf0..d856ad0 100644
--- a/pina/graph.py
+++ b/pina/graph.py
@@ -5,6 +5,7 @@ import torch
 from . import LabelTensor
 from torch_geometric.data import Data
 from torch_geometric.utils import to_undirected
+import inspect
 
 
 class Graph:
@@ -12,14 +13,17 @@ class Graph:
     Class for the graph construction.
     """
 
-    def __init__(self,
-                 x,
-                 pos,
-                 edge_index,
-                 edge_attr=None,
-                 build_edge_attr=False,
-                 undirected=False,
-                 additional_params=None):
+    def __init__(
+            self,
+            x,
+            pos,
+            edge_index,
+            edge_attr=None,
+            build_edge_attr=False,
+            undirected=False,
+            custom_build_edge_attr=None,
+            additional_params=None
+    ):
         """
         Constructor for the Graph class.
         :param x: The node features.
@@ -34,45 +38,23 @@ class Graph:
         :type build_edge_attr: bool
         :param undirected: Whether to build an undirected graph.
         :type undirected: bool
+        :param custom_build_edge_attr: Custom function to build the edge
+            attributes.
+        :type custom_build_edge_attr: function
         :param additional_params: Additional parameters.
        :type additional_params: dict
         """
 
         self.data = []
-        x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)
+        x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
         # Check input dimension consistency and store the number of graphs
         data_len = self._check_len_consistency(x, pos)
+        if inspect.isfunction(custom_build_edge_attr):
+            self._build_edge_attr = custom_build_edge_attr
 
-        # Initialize additional_parameters (if present)
-        if additional_params is not None:
-            if not isinstance(additional_params, dict):
-                raise TypeError("additional_params must be a dictionary.")
-            for param, val in additional_params.items():
-                # Check if the values are tensors or lists of tensors
-                if isinstance(val, torch.Tensor):
-                    # If the tensor is 3D, we split it into a list of 2D tensors
-                    # In this case there must be a additional parameter for each
-                    # node
-                    if val.ndim == 3:
-                        additional_params[param] = [val[i] for i in
-                                                    range(val.shape[0])]
-                    # If the tensor is 2D, we replicate it for each node
-                    elif val.ndim == 2:
-                        additional_params[param] = [val] * data_len
-                    # If the tensor is 1D, each graph has a scalar values as
-                    # additional parameter
-                    if val.ndim == 1:
-                        if len(val) == data_len:
-                            additional_params[param] = [val[i] for i in
-                                                        range(len(val))]
-                        else:
-                            additional_params[param] = [val for _ in
-                                                        range(data_len)]
-                elif not isinstance(val, list):
-                    raise TypeError("additional_params values must be tensors "
-                                    "or lists of tensors.")
-        else:
-            additional_params = {}
+        # Check consistency and initialize additional_parameters (if present)
+        additional_params = self._check_additional_params(additional_params,
+                                                          data_len)
 
         # Make the graphs undirected
         if undirected:
@@ -81,27 +63,17 @@ class Graph:
             else:
                 edge_index = to_undirected(edge_index)
 
-        if build_edge_attr:
-            if edge_attr is not None:
-                warning("Edge attributes are provided, build_edge_attr is set "
-                        "to True. The provided edge attributes will be ignored.")
-            edge_attr = self._build_edge_attr(pos, edge_index)
-
         # Prepare internal lists to create a graph list (same positions but
         # different node features)
         if isinstance(x, list) and isinstance(pos, (torch.Tensor, LabelTensor)):
             # Replicate the positions, edge_index and edge_attr
             pos, edge_index = [pos] * data_len, [edge_index] * data_len
-            if edge_attr is not None:
-                edge_attr = [edge_attr] * data_len
 
         # Prepare internal lists to create a list containing a single graph
         elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
                 torch.Tensor, LabelTensor)):
             # Encapsulate the input tensors into lists
             x, pos, edge_index = [x], [pos], [edge_index]
-            if isinstance(edge_attr, torch.Tensor):
-                edge_attr = [edge_attr]
         # Prepare internal lists to create a list of graphs (same node features
         # but different positions)
         elif (isinstance(x, (torch.Tensor, LabelTensor))
@@ -111,6 +83,10 @@ class Graph:
         elif not isinstance(x, list) and not isinstance(pos, list):
             raise TypeError("x and pos must be lists or tensors.")
 
+        # Build the edge attributes
+        edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
+                                                    data_len, edge_index, pos, x)
+
         # Perform the graph construction
         self._build_graph_list(x, pos, edge_index, edge_attr,
                                additional_params)
@@ -130,12 +106,8 @@ class Graph:
                                   **add_params_local))
 
     @staticmethod
-    def _build_edge_attr(pos, edge_index):
-        if isinstance(pos, torch.Tensor):
-            pos = [pos]
-            edge_index = [edge_index]
-        distance = [pos_[edge_index_[0]] - pos_[edge_index_[1]] ** 2 for
-                    pos_, edge_index_ in zip(pos, edge_index)]
+    def _build_edge_attr(x, pos, edge_index):
+        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
         return distance
 
     @staticmethod
@@ -166,15 +138,65 @@ class Graph:
         edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
         return x, pos, edge_index
 
+    @staticmethod
+    def _check_additional_params(additional_params, data_len):
+        if additional_params is not None:
+            if not isinstance(additional_params, dict):
+                raise TypeError("additional_params must be a dictionary.")
+            for param, val in additional_params.items():
+                # Check if the values are tensors or lists of tensors
+                if isinstance(val, torch.Tensor):
+                    # If the tensor is 3D, we split it into a list of 2D tensors
+                    # In this case there must be an additional parameter for each
+                    # node
+                    if val.ndim == 3:
+                        additional_params[param] = [val[i] for i in
+                                                    range(val.shape[0])]
+                    # If the tensor is 2D, we replicate it for each node
+                    elif val.ndim == 2:
+                        additional_params[param] = [val] * data_len
+                    # If the tensor is 1D, each graph has a scalar value as
+                    # additional parameter
+                    if val.ndim == 1:
+                        if len(val) == data_len:
+                            additional_params[param] = [val[i] for i in
+                                                        range(len(val))]
+                        else:
+                            additional_params[param] = [val for _ in
+                                                        range(data_len)]
+                elif not isinstance(val, list):
+                    raise TypeError("additional_params values must be tensors "
+                                    "or lists of tensors.")
+        else:
+            additional_params = {}
+        return additional_params
+
+    def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
+                                   edge_index, pos, x):
+        # Check if edge_attr is consistent with x and pos
+        if edge_attr is not None:
+            if build_edge_attr is True:
+                warning("edge_attr is not None. build_edge_attr will not be "
+                        "considered.")
+            if isinstance(edge_attr, list):
+                if len(edge_attr) != data_len:
+                    raise ValueError("edge_attr must have the same length as x "
+                                     "and pos.")
+            return [edge_attr] * data_len
+
+        if build_edge_attr:
+            return [self._build_edge_attr(x, pos_, edge_index_) for
+                    pos_, edge_index_ in zip(pos, edge_index)]
 
 class RadiusGraph(Graph):
-    def __init__(self,
-                 x,
-                 pos,
-                 r,
-                 build_edge_attr=False,
-                 undirected=False,
-                 additional_params=None, ):
+    def __init__(
+            self,
+            x,
+            pos,
+            r,
+            **kwargs
+    ):
 
         x, pos, edge_index = Graph._check_input_consistency(x, pos)
         if isinstance(pos, (torch.Tensor, LabelTensor)):
@@ -183,9 +205,7 @@ class RadiusGraph(Graph):
             edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
 
         super().__init__(x=x, pos=pos, edge_index=edge_index,
-                         build_edge_attr=build_edge_attr,
-                         undirected=undirected,
-                         additional_params=additional_params)
+                         **kwargs)
 
     @staticmethod
     def _radius_graph(points, r):
@@ -204,23 +224,20 @@ class RadiusGraph(Graph):
 
 
 class KNNGraph(Graph):
-    def __init__(self,
-                 x,
-                 pos,
-                 k,
-                 build_edge_attr=False,
-                 undirected=False,
-                 additional_params=None,
-                 ):
+    def __init__(
+            self,
+            x,
+            pos,
+            k,
+            **kwargs
+    ):
         x, pos, edge_index = Graph._check_input_consistency(x, pos)
         if isinstance(pos, (torch.Tensor, LabelTensor)):
             edge_index = KNNGraph._knn_graph(pos, k)
         else:
             edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
         super().__init__(x=x, pos=pos, edge_index=edge_index,
-                         build_edge_attr=build_edge_attr,
-                         undirected=undirected,
-                         additional_params=additional_params)
+                         **kwargs)
 
     @staticmethod
     def _knn_graph(points, k):
@@ -240,6 +257,7 @@ class KNNGraph(Graph):
         edge_index = torch.stack([row, col], dim=0)
         return edge_index
 
+
 class TemporalGraph(Graph):
     def __init__(
         self,
@@ -259,7 +277,8 @@ class TemporalGraph(Graph):
             edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
         additional_params = {'t': t}
         self._check_time_consistency(pos, t)
-        super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
+        super().__init__(x=x, pos=pos, edge_index=edge_index,
+                         edge_attr=edge_attr,
                          build_edge_attr=build_edge_attr,
                          undirected=undirected,
                          additional_params=additional_params)
diff --git a/tests/test_graph.py b/tests/test_graph.py
index e6ce88c..5521be0 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -7,10 +7,12 @@ from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
 @pytest.mark.parametrize(
     "x, pos",
     [
-        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
-        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
-        (torch.rand(3,10,2), torch.rand(3,10,3)),
-        (torch.rand(3,10,2), torch.rand(3,10,3)),
+        ([torch.rand(10, 2) for _ in range(3)],
+         [torch.rand(10, 3) for _ in range(3)]),
+        ([torch.rand(10, 2) for _ in range(3)],
+         [torch.rand(10, 3) for _ in range(3)]),
+        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
+        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
     ]
 )
 def test_build_multiple_graph_multiple_val(x, pos):
@@ -28,7 +30,7 @@ def test_build_multiple_graph_multiple_val(x, pos):
     assert all(d.edge_attr is not None for d in data)
     assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
 
-    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k = 3)
+    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
     data = graph.data
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
     assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
@@ -112,36 +114,39 @@ def test_build_single_graph_single_val(pos):
     assert all(d.edge_attr is not None for d in data)
     assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
 
+
 def test_additional_parameters_1():
     x = torch.rand(3, 10, 2)
     pos = torch.rand(3, 10, 2)
     additional_parameters = {'y': torch.ones(3)}
     graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
-                       additional_params=additional_parameters)
+                        additional_params=additional_parameters)
     assert len(graph.data) == 3
     data = graph.data
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
     assert all(hasattr(d, 'y') for d in data)
     assert all(d_.y == 1 for d_ in data)
 
+
 @pytest.mark.parametrize(
     "additional_parameters", [
-        ({'y': torch.rand(3,10,1)}),
-        ({'y': [torch.rand(10,1) for _ in range(3)]}),
+        ({'y': torch.rand(3, 10, 1)}),
+        ({'y': [torch.rand(10, 1) for _ in range(3)]}),
     ]
 )
 def test_additional_parameters_2(additional_parameters):
     x = torch.rand(3, 10, 2)
     pos = torch.rand(3, 10, 2)
     graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
-                       additional_params=additional_parameters)
+                        additional_params=additional_parameters)
     assert len(graph.data) == 3
     data = graph.data
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
     assert all(hasattr(d, 'y') for d in data)
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
 
+
 def test_temporal_graph():
     x = torch.rand(3, 10, 2)
     pos = torch.rand(3, 10, 2)
@@ -152,3 +157,21 @@ def test_temporal_graph():
     assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
     assert all(hasattr(d, 't') for d in data)
     assert all(d_.t == t_ for (d_, t_) in zip(data, t))
+
+
+def test_custom_build_edge_attr_func():
+    x = torch.rand(3, 10, 2)
+    pos = torch.rand(3, 10, 2)
+
+    def build_edge_attr(x, pos, edge_index):
+        return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
+
+    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
+                        custom_build_edge_attr=build_edge_attr)
+    assert len(graph.data) == 3
+    data = graph.data
+    assert all(hasattr(d, 'edge_attr') for d in data)
+    assert all(d.edge_attr.shape[1] == 4 for d in data)
+    assert all(torch.isclose(d.edge_attr,
+                             build_edge_attr(d.x, d.pos, d.edge_index)).all()
+               for d in data)
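
Usage sketch (not part of the patch): with this refactor, RadiusGraph and KNNGraph forward extra keyword arguments to Graph.__init__, so a custom edge-attribute function with the signature (x, pos, edge_index) can be supplied through custom_build_edge_attr together with build_edge_attr=True. The snippet below is a minimal illustration of that hook; squared_distance is a hypothetical helper written for this example, not a function shipped with the library.

    import torch
    from pina.graph import KNNGraph

    # Hypothetical edge-attribute builder: squared Euclidean distance between
    # the two endpoints of every edge. It matches the (x, pos, edge_index)
    # signature expected by the refactored Graph class and ignores x.
    def squared_distance(x, pos, edge_index):
        diff = pos[edge_index[0]] - pos[edge_index[1]]
        return (diff ** 2).sum(dim=-1, keepdim=True)

    x = torch.rand(3, 10, 2)    # 3 graphs, 10 nodes each, 2 node features
    pos = torch.rand(3, 10, 2)  # node coordinates for each graph

    graph = KNNGraph(x=x, pos=pos, k=3, build_edge_attr=True,
                     custom_build_edge_attr=squared_distance)
    assert all(d.edge_attr.shape[1] == 1 for d in graph.data)

The same keyword works for RadiusGraph, as exercised by test_custom_build_edge_attr_func above; when custom_build_edge_attr is omitted, the default _build_edge_attr (absolute coordinate difference) is used instead.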