Refactor GNO model and enhance Graph class documentation and error handling. Remove TemporalGraph class
This commit is contained in:
committed by
Nicola Demo
parent
bbdd5d4bf1
commit
bd24b0c1c2
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pina import Graph
|
||||
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
|
||||
from pina.graph import RadiusGraph, KNNGraph
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters):
|
||||
assert all(hasattr(d, 'y') for d in data)
|
||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||
|
||||
|
||||
def test_temporal_graph():
|
||||
x = torch.rand(3, 10, 2)
|
||||
pos = torch.rand(3, 10, 2)
|
||||
t = torch.rand(3)
|
||||
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
|
||||
assert len(graph.data) == 3
|
||||
data = graph.data
|
||||
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
|
||||
assert all(hasattr(d, 't') for d in data)
|
||||
assert all(d_.t == t_ for (d_, t_) in zip(data, t))
|
||||
|
||||
|
||||
def test_custom_build_edge_attr_func():
|
||||
x = torch.rand(3, 10, 2)
|
||||
pos = torch.rand(3, 10, 2)
|
||||
|
||||
Reference in New Issue
Block a user