Refactor GNO model and enhance Graph class documentation and error handling. Remove TemporalGraph class
This commit is contained in:
committed by
Nicola Demo
parent
bbdd5d4bf1
commit
bd24b0c1c2
109
pina/graph.py
109
pina/graph.py
@@ -25,25 +25,44 @@ class Graph:
|
||||
additional_params=None
|
||||
):
|
||||
"""
|
||||
Constructor for the Graph class.
|
||||
:param x: The node features.
|
||||
Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
|
||||
Based on the input of x and pos there could be the following cases:
|
||||
1. 1 pos, 1 x: a single graph will be created
|
||||
2. N pos, 1 x: N graphs will be created with the same node features
|
||||
3. 1 pos, N x: N graphs will be created with the same nodes but different node features
|
||||
4. N pos, N x: N graphs will be created
|
||||
|
||||
:param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features],
|
||||
or a 3D tensor of shape [n_graphs, num_nodes, num_node_features]
|
||||
or a list of such 2D tensors of shape [num_nodes, num_node_features].
|
||||
:type x: torch.Tensor or list[torch.Tensor]
|
||||
:param pos: The node positions.
|
||||
:param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates],
|
||||
or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates]
|
||||
or a list of such 2D tensors of shape [num_nodes, num_coordinates].
|
||||
:type pos: torch.Tensor or list[torch.Tensor]
|
||||
:param edge_index: The edge index.
|
||||
:param edge_index: The edge index defining connections between nodes.
|
||||
It should be a 2D tensor of shape [2, num_edges]
|
||||
or a 3D tensor of shape [n_graphs, 2, num_edges]
|
||||
or a list of such 2D tensors of shape [2, num_edges].
|
||||
:type edge_index: torch.Tensor or list[torch.Tensor]
|
||||
:param edge_attr: The edge attributes.
|
||||
:type edge_attr: torch.Tensor or list[torch.Tensor]
|
||||
:param build_edge_attr: Whether to build the edge attributes.
|
||||
:type build_edge_attr: bool
|
||||
:param undirected: Whether to build an undirected graph.
|
||||
:type undirected: bool
|
||||
:param custom_build_edge_attr: Custom function to build the edge
|
||||
attributes.
|
||||
:type custom_build_edge_attr: function
|
||||
:param additional_params: Additional parameters.
|
||||
:type additional_params: dict
|
||||
:param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features]
|
||||
or be a list of such tensors for multiple graphs.
|
||||
:type edge_attr: torch.Tensor or list[torch.Tensor], optional
|
||||
:param build_edge_attr: Whether to compute edge attributes during initialization.
|
||||
:type build_edge_attr: bool, default=False
|
||||
:param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges.
|
||||
:type undirected: bool, default=False
|
||||
:param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically.
|
||||
The function should take (x, pos, edge_index) as input and return a tensor
|
||||
of shape [num_edges, num_edge_features].
|
||||
:type custom_build_edge_attr: function or callable, optional
|
||||
:param additional_params: Dictionary containing extra attributes to be added to each Data object.
|
||||
Keys represent attribute names, and values should be tensors or lists of tensors.
|
||||
:type additional_params: dict, optional
|
||||
|
||||
Note: if x, pos, and edge_index are both lists or 3D tensors, then len(x) == len(pos) == len(edge_index).
|
||||
"""
|
||||
|
||||
self.data = []
|
||||
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||
|
||||
@@ -85,7 +104,8 @@ class Graph:
|
||||
|
||||
# Build the edge attributes
|
||||
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
|
||||
data_len, edge_index, pos, x)
|
||||
data_len, edge_index, pos,
|
||||
x)
|
||||
|
||||
# Perform the graph construction
|
||||
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
||||
@@ -128,14 +148,32 @@ class Graph:
|
||||
# If x is a 3D tensor, we split it into a list of 2D tensors
|
||||
if isinstance(x, torch.Tensor) and x.ndim == 3:
|
||||
x = [x[i] for i in range(x.shape[0])]
|
||||
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
|
||||
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
|
||||
raise TypeError("x must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
|
||||
# If pos is a 3D tensor, we split it into a list of 2D tensors
|
||||
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
|
||||
pos = [pos[i] for i in range(pos.shape[0])]
|
||||
elif not (isinstance(pos, list) and all(
|
||||
t.ndim == 2 for t in pos)) and not (
|
||||
isinstance(pos, torch.Tensor) and pos.ndim == 2):
|
||||
raise TypeError("pos must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
|
||||
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
|
||||
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||
if edge_index is not None:
|
||||
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||
elif not (isinstance(edge_index, list) and all(
|
||||
t.ndim == 2 for t in edge_index)) and not (
|
||||
isinstance(edge_index,
|
||||
torch.Tensor) and edge_index.ndim == 2):
|
||||
raise TypeError(
|
||||
"edge_index must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
|
||||
return x, pos, edge_index
|
||||
|
||||
@staticmethod
|
||||
@@ -180,12 +218,12 @@ class Graph:
|
||||
"considered.")
|
||||
if isinstance(edge_attr, list):
|
||||
if len(edge_attr) != data_len:
|
||||
raise ValueError("edge_attr must have the same length as x "
|
||||
raise TypeError("edge_attr must have the same length as x "
|
||||
"and pos.")
|
||||
return [edge_attr] * data_len
|
||||
|
||||
if build_edge_attr:
|
||||
return [self._build_edge_attr(x,pos_, edge_index_) for
|
||||
return [self._build_edge_attr(x, pos_, edge_index_) for
|
||||
pos_, edge_index_ in zip(pos, edge_index)]
|
||||
|
||||
|
||||
@@ -256,34 +294,3 @@ class KNNGraph(Graph):
|
||||
col = knn_indices.flatten()
|
||||
edge_index = torch.stack([row, col], dim=0)
|
||||
return edge_index
|
||||
|
||||
|
||||
class TemporalGraph(Graph):
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
t,
|
||||
edge_index=None,
|
||||
edge_attr=None,
|
||||
build_edge_attr=False,
|
||||
undirected=False,
|
||||
r=None
|
||||
):
|
||||
|
||||
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||
print(len(pos))
|
||||
if edge_index is None:
|
||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||
additional_params = {'t': t}
|
||||
self._check_time_consistency(pos, t)
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
build_edge_attr=build_edge_attr,
|
||||
undirected=undirected,
|
||||
additional_params=additional_params)
|
||||
|
||||
@staticmethod
|
||||
def _check_time_consistency(pos, times):
|
||||
if len(pos) != len(times):
|
||||
raise ValueError("pos and times must have the same length.")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pina import Graph
|
||||
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
|
||||
from pina.graph import RadiusGraph, KNNGraph
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters):
|
||||
assert all(hasattr(d, 'y') for d in data)
|
||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||
|
||||
|
||||
def test_temporal_graph():
|
||||
x = torch.rand(3, 10, 2)
|
||||
pos = torch.rand(3, 10, 2)
|
||||
t = torch.rand(3)
|
||||
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
|
||||
assert len(graph.data) == 3
|
||||
data = graph.data
|
||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||
assert all(hasattr(d, 't') for d in data)
|
||||
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
||||
|
||||
|
||||
def test_custom_build_edge_attr_func():
|
||||
x = torch.rand(3, 10, 2)
|
||||
pos = torch.rand(3, 10, 2)
|
||||
|
||||
Reference in New Issue
Block a user