From 78b276d995126d191997bddef8678e6705bf03e1 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 4 Feb 2025 21:37:49 +0100 Subject: [PATCH] Add TemporalGraph class --- pina/graph.py | 30 ++++++++++++++++++++++++++++++ tests/test_graph.py | 13 ++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/pina/graph.py b/pina/graph.py index 22d7082..7365bf0 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -1,6 +1,7 @@ from logging import warning import torch + from . import LabelTensor from torch_geometric.data import Data from torch_geometric.utils import to_undirected @@ -238,3 +239,32 @@ class KNNGraph(Graph): col = knn_indices.flatten() edge_index = torch.stack([row, col], dim=0) return edge_index + +class TemporalGraph(Graph): + def __init__( + self, + x, + pos, + t, + edge_index=None, + edge_attr=None, + build_edge_attr=False, + undirected=False, + r=None + ): + + x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) + print(len(pos)) + if edge_index is None: + edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] + additional_params = {'t': t} + self._check_time_consistency(pos, t) + super().__init__(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, + build_edge_attr=build_edge_attr, + undirected=undirected, + additional_params=additional_params) + + @staticmethod + def _check_time_consistency(pos, times): + if len(pos) != len(times): + raise ValueError("pos and times must have the same length.") diff --git a/tests/test_graph.py b/tests/test_graph.py index 6886dd9..e6ce88c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ import pytest import torch from pina import Graph -from pina.graph import RadiusGraph, KNNGraph +from pina.graph import RadiusGraph, KNNGraph, TemporalGraph @pytest.mark.parametrize( @@ -141,3 +141,14 @@ def test_additional_parameters_2(additional_parameters): assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) assert all(hasattr(d, 'y') for d in data) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + +def test_temporal_graph(): + x = torch.rand(3, 10, 2) + pos = torch.rand(3, 10, 2) + t = torch.rand(3) + graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(hasattr(d, 't') for d in data) + assert all(d_.t == t_ for (d_, t_) in zip(data, t))