Add TemporalGraph class

This commit is contained in:
FilippoOlivo
2025-02-04 21:37:49 +01:00
committed by Nicola Demo
parent 54a62dee26
commit 78b276d995
2 changed files with 42 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
from logging import warning from logging import warning
import torch import torch
from . import LabelTensor from . import LabelTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import to_undirected from torch_geometric.utils import to_undirected
@@ -238,3 +239,32 @@ class KNNGraph(Graph):
col = knn_indices.flatten() col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0) edge_index = torch.stack([row, col], dim=0)
return edge_index return edge_index
class TemporalGraph(Graph):
def __init__(
self,
x,
pos,
t,
edge_index=None,
edge_attr=None,
build_edge_attr=False,
undirected=False,
r=None
):
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
print(len(pos))
if edge_index is None:
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t}
self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)
@staticmethod
def _check_time_consistency(pos, times):
if len(pos) != len(times):
raise ValueError("pos and times must have the same length.")

View File

@@ -1,7 +1,7 @@
import pytest import pytest
import torch import torch
from pina import Graph from pina import Graph
from pina.graph import RadiusGraph, KNNGraph from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -141,3 +141,14 @@ def test_additional_parameters_2(additional_parameters):
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data) assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
def test_temporal_graph():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
t = torch.rand(3)
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 't') for d in data)
assert all(d_.t == t_ for (d_, t_) in zip(data, t))