Refactor GNO model and enhance Graph class documentation and error handling. Remove TemporalGraph class

FilippoOlivo
2025-02-05 17:10:26 +01:00
committed by Nicola Demo
parent bbdd5d4bf1
commit bd24b0c1c2
2 changed files with 59 additions and 66 deletions


@@ -25,25 +25,44 @@ class Graph:
additional_params=None
):
"""
Constructor for the Graph class.
:param x: The node features.
Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
Depending on the shapes of x and pos, the following cases are possible:
1. 1 pos, 1 x: a single graph is created
2. N pos, 1 x: N graphs are created, all sharing the same node features
3. 1 pos, N x: N graphs are created with the same nodes but different node features
4. N pos, N x: N distinct graphs are created
(A usage sketch for these cases follows this hunk.)
:param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features],
a 3D tensor of shape [n_graphs, num_nodes, num_node_features],
or a list of such 2D tensors.
:type x: torch.Tensor or list[torch.Tensor]
:param pos: The node positions.
:param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates],
a 3D tensor of shape [n_graphs, num_nodes, num_coordinates],
or a list of such 2D tensors.
:type pos: torch.Tensor or list[torch.Tensor]
:param edge_index: The edge index.
:param edge_index: The edge index defining the connections between nodes.
It should be a 2D tensor of shape [2, num_edges],
a 3D tensor of shape [n_graphs, 2, num_edges],
or a list of such 2D tensors.
:type edge_index: torch.Tensor or list[torch.Tensor]
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor or list[torch.Tensor]
:param build_edge_attr: Whether to build the edge attributes.
:type build_edge_attr: bool
:param undirected: Whether to build an undirected graph.
:type undirected: bool
:param custom_build_edge_attr: Custom function to build the edge
attributes.
:type custom_build_edge_attr: function
:param additional_params: Additional parameters.
:type additional_params: dict
:param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features]
or be a list of such tensors for multiple graphs.
:type edge_attr: torch.Tensor or list[torch.Tensor], optional
:param build_edge_attr: Whether to compute edge attributes during initialization.
:type build_edge_attr: bool, default=False
:param undirected: If True, converts each graph into an undirected graph by adding reciprocal edges.
:type undirected: bool, default=False
:param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically.
The function should take (x, pos, edge_index) as input and return a tensor
of shape [num_edges, num_edge_features].
:type custom_build_edge_attr: function or callable, optional
:param additional_params: Dictionary containing extra attributes to be added to each Data object.
Keys represent attribute names, and values should be tensors or lists of tensors.
:type additional_params: dict, optional
Note: if x, pos, and edge_index are all lists or 3D tensors, then len(x) == len(pos) == len(edge_index) must hold.
"""
self.data = []
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
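
The four cases listed in the new docstring can be exercised roughly as follows. This is a minimal sketch rather than code from the commit: the tensor shapes and variable names are illustrative, and it assumes Graph is importable from the module shown in this diff.

import torch

pos = torch.rand(100, 2)                      # [num_nodes, num_coordinates]
x = torch.rand(100, 8)                        # [num_nodes, num_node_features]
edge_index = torch.randint(0, 100, (2, 300))  # [2, num_edges]

# Case 1: 1 pos, 1 x -> a single graph.
g_single = Graph(x=x, pos=pos, edge_index=edge_index)

# Case 4: N pos, N x -> N graphs, here passed as 3D tensors with N = 10.
pos_batch = torch.rand(10, 100, 2)
x_batch = torch.rand(10, 100, 8)
edge_index_batch = torch.randint(0, 100, (10, 2, 300))
g_batch = Graph(x=x_batch, pos=pos_batch, edge_index=edge_index_batch)

# Cases 2 and 3 mix a single tensor with a batched one, e.g. shared node features:
g_shared_x = Graph(x=x, pos=pos_batch, edge_index=edge_index_batch)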
@@ -85,7 +104,8 @@ class Graph:
# Build the edge attributes
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
data_len, edge_index, pos, x)
data_len, edge_index, pos,
x)
# Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
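
The docstring above states that custom_build_edge_attr receives (x, pos, edge_index) and returns a tensor of shape [num_edges, num_edge_features]. A user-supplied builder could therefore look like the sketch below; the particular features (relative displacement plus Euclidean distance) are an assumption for illustration, not taken from the repository.

import torch

def custom_edge_attr(x, pos, edge_index):
    # Per-edge relative displacement and Euclidean distance between endpoints.
    src, dst = edge_index
    displacement = pos[dst] - pos[src]                  # [num_edges, num_coordinates]
    distance = displacement.norm(dim=-1, keepdim=True)  # [num_edges, 1]
    return torch.cat([displacement, distance], dim=-1)  # [num_edges, num_coordinates + 1]

# graph = Graph(x=x, pos=pos, edge_index=edge_index,
#               custom_build_edge_attr=custom_edge_attr)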
@@ -128,14 +148,32 @@ class Graph:
# If x is a 3D tensor, we split it into a list of 2D tensors
if isinstance(x, torch.Tensor) and x.ndim == 3:
x = [x[i] for i in range(x.shape[0])]
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
raise TypeError("x must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
# If pos is a 3D tensor, we split it into a list of 2D tensors
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
pos = [pos[i] for i in range(pos.shape[0])]
elif not (isinstance(pos, list) and all(
t.ndim == 2 for t in pos)) and not (
isinstance(pos, torch.Tensor) and pos.ndim == 2):
raise TypeError("pos must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
if edge_index is not None:
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
elif not (isinstance(edge_index, list) and all(
t.ndim == 2 for t in edge_index)) and not (
isinstance(edge_index,
torch.Tensor) and edge_index.ndim == 2):
raise TypeError(
"edge_index must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")
return x, pos, edge_index
@staticmethod
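
A short sketch of the input forms the normalisation above accepts or rejects; the shapes are illustrative only.

import torch

x_2d = torch.rand(100, 8)                         # accepted: a single 2D tensor (one graph)
x_3d = torch.rand(10, 100, 8)                     # accepted: split into a list of ten 2D tensors
x_list = [torch.rand(100, 8) for _ in range(10)]  # accepted: already a list of 2D tensors
x_bad = torch.rand(100)                           # rejected: a 1D tensor raises TypeError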
@@ -180,12 +218,12 @@ class Graph:
"considered.")
if isinstance(edge_attr, list):
if len(edge_attr) != data_len:
raise ValueError("edge_attr must have the same length as x "
raise TypeError("edge_attr must have the same length as x "
"and pos.")
return [edge_attr] * data_len
if build_edge_attr:
return [self._build_edge_attr(x,pos_, edge_index_) for
return [self._build_edge_attr(x, pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)]
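
When edge attributes are supplied explicitly instead of being built, the check above expects a per-graph list whose length matches the number of graphs, and a mismatch now raises TypeError rather than ValueError. A hedged sketch, with names and shapes assumed for illustration:

import torch

n_graphs, n_edges = 10, 300
# One [num_edges, num_edge_features] tensor per graph; the list length must equal n_graphs.
edge_attr_list = [torch.rand(n_edges, 3) for _ in range(n_graphs)]
# graph = Graph(x=x_list, pos=pos_list, edge_index=edge_index_list,
#               edge_attr=edge_attr_list)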
@@ -256,34 +294,3 @@ class KNNGraph(Graph):
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
return edge_index
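
The context lines above come from KNNGraph's connectivity construction. The following self-contained sketch reproduces the same idea from scratch (an assumption, not the class's exact implementation): connect each node to its k nearest neighbours and stack the row and column indices into a [2, num_edges] tensor.

import torch

def knn_edge_index(pos, k):
    dist = torch.cdist(pos, pos)                       # pairwise distances [N, N]
    dist.fill_diagonal_(float("inf"))                  # exclude self-loops
    knn_indices = dist.topk(k, largest=False).indices  # [N, k] nearest neighbours
    row = torch.arange(pos.shape[0]).repeat_interleave(k)
    col = knn_indices.flatten()
    return torch.stack([row, col], dim=0)              # [2, N * k]

edge_index = knn_edge_index(torch.rand(100, 2), k=5)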
class TemporalGraph(Graph):
def __init__(
self,
x,
pos,
t,
edge_index=None,
edge_attr=None,
build_edge_attr=False,
undirected=False,
r=None
):
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
print(len(pos))
if edge_index is None:
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t}
self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index,
edge_attr=edge_attr,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)
@staticmethod
def _check_time_consistency(pos, times):
if len(pos) != len(times):
raise ValueError("pos and times must have the same length.")