Add TemporalGraph class
This commit is contained in:
committed by
Nicola Demo
parent
54a62dee26
commit
78b276d995
@@ -1,6 +1,7 @@
|
||||
from logging import warning
|
||||
|
||||
import torch
|
||||
|
||||
from . import LabelTensor
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.utils import to_undirected
|
||||
@@ -238,3 +239,32 @@ class KNNGraph(Graph):
|
||||
col = knn_indices.flatten()
|
||||
edge_index = torch.stack([row, col], dim=0)
|
||||
return edge_index
|
||||
|
||||
class TemporalGraph(Graph):
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
t,
|
||||
edge_index=None,
|
||||
edge_attr=None,
|
||||
build_edge_attr=False,
|
||||
undirected=False,
|
||||
r=None
|
||||
):
|
||||
|
||||
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
|
||||
print(len(pos))
|
||||
if edge_index is None:
|
||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||
additional_params = {'t': t}
|
||||
self._check_time_consistency(pos, t)
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
|
||||
build_edge_attr=build_edge_attr,
|
||||
undirected=undirected,
|
||||
additional_params=additional_params)
|
||||
|
||||
@staticmethod
|
||||
def _check_time_consistency(pos, times):
|
||||
if len(pos) != len(times):
|
||||
raise ValueError("pos and times must have the same length.")
|
||||
|
||||
Reference in New Issue
Block a user