Refactor GNO model and enhance Graph class documentation and error handling. Remove TemporalGraph class

This commit is contained in:
FilippoOlivo
2025-02-05 17:10:26 +01:00
committed by Nicola Demo
parent bbdd5d4bf1
commit bd24b0c1c2
2 changed files with 59 additions and 66 deletions

View File

@@ -25,25 +25,44 @@ class Graph:
additional_params=None additional_params=None
): ):
""" """
Constructor for the Graph class. Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
:param x: The node features. Based on the input of x and pos there could be the following cases:
1. 1 pos, 1 x: a single graph will be created
2. N pos, 1 x: N graphs will be created with the same node features
3. 1 pos, N x: N graphs will be created with the same nodes but different node features
4. N pos, N x: N graphs will be created
:param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features],
or a 3D tensor of shape [n_graphs, num_nodes, num_node_features]
or a list of such 2D tensors of shape [num_nodes, num_node_features].
:type x: torch.Tensor or list[torch.Tensor] :type x: torch.Tensor or list[torch.Tensor]
:param pos: The node positions. :param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates],
or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates]
or a list of such 2D tensors of shape [num_nodes, num_coordinates].
:type pos: torch.Tensor or list[torch.Tensor] :type pos: torch.Tensor or list[torch.Tensor]
:param edge_index: The edge index. :param edge_index: The edge index defining connections between nodes.
It should be a 2D tensor of shape [2, num_edges]
or a 3D tensor of shape [n_graphs, 2, num_edges]
or a list of such 2D tensors of shape [2, num_edges].
:type edge_index: torch.Tensor or list[torch.Tensor] :type edge_index: torch.Tensor or list[torch.Tensor]
:param edge_attr: The edge attributes. :param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features]
:type edge_attr: torch.Tensor or list[torch.Tensor] or be a list of such tensors for multiple graphs.
:param build_edge_attr: Whether to build the edge attributes. :type edge_attr: torch.Tensor or list[torch.Tensor], optional
:type build_edge_attr: bool :param build_edge_attr: Whether to compute edge attributes during initialization.
:param undirected: Whether to build an undirected graph. :type build_edge_attr: bool, default=False
:type undirected: bool :param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges.
:param custom_build_edge_attr: Custom function to build the edge :type undirected: bool, default=False
attributes. :param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically.
:type custom_build_edge_attr: function The function should take (x, pos, edge_index) as input and return a tensor
:param additional_params: Additional parameters. of shape [num_edges, num_edge_features].
:type additional_params: dict :type custom_build_edge_attr: function or callable, optional
:param additional_params: Dictionary containing extra attributes to be added to each Data object.
Keys represent attribute names, and values should be tensors or lists of tensors.
:type additional_params: dict, optional
Note: if x, pos, and edge_index are both lists or 3D tensors, then len(x) == len(pos) == len(edge_index).
""" """
self.data = [] self.data = []
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
@@ -85,7 +104,8 @@ class Graph:
# Build the edge attributes # Build the edge attributes
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr, edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
data_len, edge_index, pos, x) data_len, edge_index, pos,
x)
# Perform the graph construction # Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
@@ -128,14 +148,32 @@ class Graph:
# If x is a 3D tensor, we split it into a list of 2D tensors # If x is a 3D tensor, we split it into a list of 2D tensors
if isinstance(x, torch.Tensor) and x.ndim == 3: if isinstance(x, torch.Tensor) and x.ndim == 3:
x = [x[i] for i in range(x.shape[0])] x = [x[i] for i in range(x.shape[0])]
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
raise TypeError("x must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
# If pos is a 3D tensor, we split it into a list of 2D tensors # If pos is a 3D tensor, we split it into a list of 2D tensors
if isinstance(pos, torch.Tensor) and pos.ndim == 3: if isinstance(pos, torch.Tensor) and pos.ndim == 3:
pos = [pos[i] for i in range(pos.shape[0])] pos = [pos[i] for i in range(pos.shape[0])]
elif not (isinstance(pos, list) and all(
t.ndim == 2 for t in pos)) and not (
isinstance(pos, torch.Tensor) and pos.ndim == 2):
raise TypeError("pos must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
# If edge_index is a 3D tensor, we split it into a list of 2D tensors # If edge_index is a 3D tensor, we split it into a list of 2D tensors
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: if edge_index is not None:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])] if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
elif not (isinstance(edge_index, list) and all(
t.ndim == 2 for t in edge_index)) and not (
isinstance(edge_index,
torch.Tensor) and edge_index.ndim == 2):
raise TypeError(
"edge_index must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
return x, pos, edge_index return x, pos, edge_index
@staticmethod @staticmethod
@@ -180,12 +218,12 @@ class Graph:
"considered.") "considered.")
if isinstance(edge_attr, list): if isinstance(edge_attr, list):
if len(edge_attr) != data_len: if len(edge_attr) != data_len:
raise ValueError("edge_attr must have the same length as x " raise TypeError("edge_attr must have the same length as x "
"and pos.") "and pos.")
return [edge_attr] * data_len return [edge_attr] * data_len
if build_edge_attr: if build_edge_attr:
return [self._build_edge_attr(x,pos_, edge_index_) for return [self._build_edge_attr(x, pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)] pos_, edge_index_ in zip(pos, edge_index)]
@@ -256,34 +294,3 @@ class KNNGraph(Graph):
col = knn_indices.flatten() col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0) edge_index = torch.stack([row, col], dim=0)
return edge_index return edge_index
class TemporalGraph(Graph):
def __init__(
self,
x,
pos,
t,
edge_index=None,
edge_attr=None,
build_edge_attr=False,
undirected=False,
r=None
):
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
print(len(pos))
if edge_index is None:
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t}
self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index,
edge_attr=edge_attr,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)
@staticmethod
def _check_time_consistency(pos, times):
if len(pos) != len(times):
raise ValueError("pos and times must have the same length.")

View File

@@ -1,7 +1,6 @@
import pytest import pytest
import torch import torch
from pina import Graph from pina.graph import RadiusGraph, KNNGraph
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters):
assert all(hasattr(d, 'y') for d in data) assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
def test_temporal_graph():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
t = torch.rand(3)
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 't') for d in data)
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
def test_custom_build_edge_attr_func(): def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2) x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2) pos = torch.rand(3, 10, 2)