Add TemporalGraph class
This commit is contained in:
committed by
Nicola Demo
parent
54a62dee26
commit
78b276d995
@@ -1,6 +1,7 @@
|
|||||||
from logging import warning
|
from logging import warning
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import LabelTensor
|
from . import LabelTensor
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.utils import to_undirected
|
from torch_geometric.utils import to_undirected
|
||||||
@@ -238,3 +239,32 @@ class KNNGraph(Graph):
|
|||||||
col = knn_indices.flatten()
|
col = knn_indices.flatten()
|
||||||
edge_index = torch.stack([row, col], dim=0)
|
edge_index = torch.stack([row, col], dim=0)
|
||||||
return edge_index
|
return edge_index
|
||||||
|
|
||||||
|
class TemporalGraph(Graph):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
pos,
|
||||||
|
t,
|
||||||
|
edge_index=None,
|
||||||
|
edge_attr=None,
|
||||||
|
build_edge_attr=False,
|
||||||
|
undirected=False,
|
||||||
|
r=None
|
||||||
|
):
|
||||||
|
|
||||||
|
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||||
|
print(len(pos))
|
||||||
|
if edge_index is None:
|
||||||
|
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||||
|
additional_params = {'t': t}
|
||||||
|
self._check_time_consistency(pos, t)
|
||||||
|
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
|
||||||
|
build_edge_attr=build_edge_attr,
|
||||||
|
undirected=undirected,
|
||||||
|
additional_params=additional_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_time_consistency(pos, times):
|
||||||
|
if len(pos) != len(times):
|
||||||
|
raise ValueError("pos and times must have the same length.")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from pina import Graph
|
from pina import Graph
|
||||||
from pina.graph import RadiusGraph, KNNGraph
|
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -141,3 +141,14 @@ def test_additional_parameters_2(additional_parameters):
|
|||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
assert all(hasattr(d, 'y') for d in data)
|
assert all(hasattr(d, 'y') for d in data)
|
||||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
|
|
||||||
|
def test_temporal_graph():
|
||||||
|
x = torch.rand(3, 10, 2)
|
||||||
|
pos = torch.rand(3, 10, 2)
|
||||||
|
t = torch.rand(3)
|
||||||
|
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
|
||||||
|
assert len(graph.data) == 3
|
||||||
|
data = graph.data
|
||||||
|
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||||
|
assert all(hasattr(d, 't') for d in data)
|
||||||
|
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
||||||
|
|||||||
Reference in New Issue
Block a user