Refactor GNO model and enhance Graph class documentation and error handling. Remove TemporalGraph class
This commit is contained in:
committed by
Nicola Demo
parent
bbdd5d4bf1
commit
bd24b0c1c2
109
pina/graph.py
109
pina/graph.py
@@ -25,25 +25,44 @@ class Graph:
|
|||||||
additional_params=None
|
additional_params=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Constructor for the Graph class.
|
Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
|
||||||
:param x: The node features.
|
Based on the input of x and pos there could be the following cases:
|
||||||
|
1. 1 pos, 1 x: a single graph will be created
|
||||||
|
2. N pos, 1 x: N graphs will be created with the same node features
|
||||||
|
3. 1 pos, N x: N graphs will be created with the same nodes but different node features
|
||||||
|
4. N pos, N x: N graphs will be created
|
||||||
|
|
||||||
|
:param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features],
|
||||||
|
or a 3D tensor of shape [n_graphs, num_nodes, num_node_features]
|
||||||
|
or a list of such 2D tensors of shape [num_nodes, num_node_features].
|
||||||
:type x: torch.Tensor or list[torch.Tensor]
|
:type x: torch.Tensor or list[torch.Tensor]
|
||||||
:param pos: The node positions.
|
:param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates],
|
||||||
|
or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates]
|
||||||
|
or a list of such 2D tensors of shape [num_nodes, num_coordinates].
|
||||||
:type pos: torch.Tensor or list[torch.Tensor]
|
:type pos: torch.Tensor or list[torch.Tensor]
|
||||||
:param edge_index: The edge index.
|
:param edge_index: The edge index defining connections between nodes.
|
||||||
|
It should be a 2D tensor of shape [2, num_edges]
|
||||||
|
or a 3D tensor of shape [n_graphs, 2, num_edges]
|
||||||
|
or a list of such 2D tensors of shape [2, num_edges].
|
||||||
:type edge_index: torch.Tensor or list[torch.Tensor]
|
:type edge_index: torch.Tensor or list[torch.Tensor]
|
||||||
:param edge_attr: The edge attributes.
|
:param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features]
|
||||||
:type edge_attr: torch.Tensor or list[torch.Tensor]
|
or be a list of such tensors for multiple graphs.
|
||||||
:param build_edge_attr: Whether to build the edge attributes.
|
:type edge_attr: torch.Tensor or list[torch.Tensor], optional
|
||||||
:type build_edge_attr: bool
|
:param build_edge_attr: Whether to compute edge attributes during initialization.
|
||||||
:param undirected: Whether to build an undirected graph.
|
:type build_edge_attr: bool, default=False
|
||||||
:type undirected: bool
|
:param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges.
|
||||||
:param custom_build_edge_attr: Custom function to build the edge
|
:type undirected: bool, default=False
|
||||||
attributes.
|
:param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically.
|
||||||
:type custom_build_edge_attr: function
|
The function should take (x, pos, edge_index) as input and return a tensor
|
||||||
:param additional_params: Additional parameters.
|
of shape [num_edges, num_edge_features].
|
||||||
:type additional_params: dict
|
:type custom_build_edge_attr: function or callable, optional
|
||||||
|
:param additional_params: Dictionary containing extra attributes to be added to each Data object.
|
||||||
|
Keys represent attribute names, and values should be tensors or lists of tensors.
|
||||||
|
:type additional_params: dict, optional
|
||||||
|
|
||||||
|
Note: if x, pos, and edge_index are both lists or 3D tensors, then len(x) == len(pos) == len(edge_index).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.data = []
|
self.data = []
|
||||||
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||||
|
|
||||||
@@ -85,7 +104,8 @@ class Graph:
|
|||||||
|
|
||||||
# Build the edge attributes
|
# Build the edge attributes
|
||||||
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
|
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
|
||||||
data_len, edge_index, pos, x)
|
data_len, edge_index, pos,
|
||||||
|
x)
|
||||||
|
|
||||||
# Perform the graph construction
|
# Perform the graph construction
|
||||||
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
||||||
@@ -128,14 +148,32 @@ class Graph:
|
|||||||
# If x is a 3D tensor, we split it into a list of 2D tensors
|
# If x is a 3D tensor, we split it into a list of 2D tensors
|
||||||
if isinstance(x, torch.Tensor) and x.ndim == 3:
|
if isinstance(x, torch.Tensor) and x.ndim == 3:
|
||||||
x = [x[i] for i in range(x.shape[0])]
|
x = [x[i] for i in range(x.shape[0])]
|
||||||
|
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
|
||||||
|
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
|
||||||
|
raise TypeError("x must be either a list of 2D tensors or a 2D "
|
||||||
|
"tensor or a 3D tensor")
|
||||||
|
|
||||||
# If pos is a 3D tensor, we split it into a list of 2D tensors
|
# If pos is a 3D tensor, we split it into a list of 2D tensors
|
||||||
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
|
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
|
||||||
pos = [pos[i] for i in range(pos.shape[0])]
|
pos = [pos[i] for i in range(pos.shape[0])]
|
||||||
|
elif not (isinstance(pos, list) and all(
|
||||||
|
t.ndim == 2 for t in pos)) and not (
|
||||||
|
isinstance(pos, torch.Tensor) and pos.ndim == 2):
|
||||||
|
raise TypeError("pos must be either a list of 2D tensors or a 2D "
|
||||||
|
"tensor or a 3D tensor")
|
||||||
|
|
||||||
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
|
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
|
||||||
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
if edge_index is not None:
|
||||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
||||||
|
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||||
|
elif not (isinstance(edge_index, list) and all(
|
||||||
|
t.ndim == 2 for t in edge_index)) and not (
|
||||||
|
isinstance(edge_index,
|
||||||
|
torch.Tensor) and edge_index.ndim == 2):
|
||||||
|
raise TypeError(
|
||||||
|
"edge_index must be either a list of 2D tensors or a 2D "
|
||||||
|
"tensor or a 3D tensor")
|
||||||
|
|
||||||
return x, pos, edge_index
|
return x, pos, edge_index
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -180,12 +218,12 @@ class Graph:
|
|||||||
"considered.")
|
"considered.")
|
||||||
if isinstance(edge_attr, list):
|
if isinstance(edge_attr, list):
|
||||||
if len(edge_attr) != data_len:
|
if len(edge_attr) != data_len:
|
||||||
raise ValueError("edge_attr must have the same length as x "
|
raise TypeError("edge_attr must have the same length as x "
|
||||||
"and pos.")
|
"and pos.")
|
||||||
return [edge_attr] * data_len
|
return [edge_attr] * data_len
|
||||||
|
|
||||||
if build_edge_attr:
|
if build_edge_attr:
|
||||||
return [self._build_edge_attr(x,pos_, edge_index_) for
|
return [self._build_edge_attr(x, pos_, edge_index_) for
|
||||||
pos_, edge_index_ in zip(pos, edge_index)]
|
pos_, edge_index_ in zip(pos, edge_index)]
|
||||||
|
|
||||||
|
|
||||||
@@ -256,34 +294,3 @@ class KNNGraph(Graph):
|
|||||||
col = knn_indices.flatten()
|
col = knn_indices.flatten()
|
||||||
edge_index = torch.stack([row, col], dim=0)
|
edge_index = torch.stack([row, col], dim=0)
|
||||||
return edge_index
|
return edge_index
|
||||||
|
|
||||||
|
|
||||||
class TemporalGraph(Graph):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
pos,
|
|
||||||
t,
|
|
||||||
edge_index=None,
|
|
||||||
edge_attr=None,
|
|
||||||
build_edge_attr=False,
|
|
||||||
undirected=False,
|
|
||||||
r=None
|
|
||||||
):
|
|
||||||
|
|
||||||
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
|
||||||
print(len(pos))
|
|
||||||
if edge_index is None:
|
|
||||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
|
||||||
additional_params = {'t': t}
|
|
||||||
self._check_time_consistency(pos, t)
|
|
||||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
|
||||||
edge_attr=edge_attr,
|
|
||||||
build_edge_attr=build_edge_attr,
|
|
||||||
undirected=undirected,
|
|
||||||
additional_params=additional_params)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_time_consistency(pos, times):
|
|
||||||
if len(pos) != len(times):
|
|
||||||
raise ValueError("pos and times must have the same length.")
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from pina import Graph
|
from pina.graph import RadiusGraph, KNNGraph
|
||||||
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters):
|
|||||||
assert all(hasattr(d, 'y') for d in data)
|
assert all(hasattr(d, 'y') for d in data)
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
|
|
||||||
|
|
||||||
def test_temporal_graph():
|
|
||||||
x = torch.rand(3, 10, 2)
|
|
||||||
pos = torch.rand(3, 10, 2)
|
|
||||||
t = torch.rand(3)
|
|
||||||
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
|
|
||||||
assert len(graph.data) == 3
|
|
||||||
data = graph.data
|
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
|
||||||
assert all(hasattr(d, 't') for d in data)
|
|
||||||
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_build_edge_attr_func():
|
def test_custom_build_edge_attr_func():
|
||||||
x = torch.rand(3, 10, 2)
|
x = torch.rand(3, 10, 2)
|
||||||
pos = torch.rand(3, 10, 2)
|
pos = torch.rand(3, 10, 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user