Simplify Graph class (#459)

* Simplifying Graph class and adjust tests

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-03 09:30:44 +01:00
committed by Nicola Demo
parent 4c3e305b09
commit ab6ca78d85
7 changed files with 909 additions and 719 deletions

View File

@@ -1,163 +1,346 @@
import pytest
import torch
from pina.graph import RadiusGraph, KNNGraph
from pina import LabelTensor
from pina.graph import RadiusGraph, KNNGraph, Graph
from torch_geometric.data import Data
def build_edge_attr(pos, edge_index):
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
@pytest.mark.parametrize(
"x, pos",
[
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
]
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_multiple_graph_multiple_val(x, pos):
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
def test_build_graph(x, pos):
edge_index = torch.tensor(
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]],
dtype=torch.int64,
)
graph = Graph(x=x, pos=pos, edge_index=edge_index)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
def test_build_single_graph_multiple_val():
x = torch.rand(10, 2)
pos = torch.rand(10, 3)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 1
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert len(graph.data) == 1
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
x = torch.rand(10, 2)
pos = torch.rand(10, 3)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
assert len(graph.data) == 1
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert len(graph.data) == 1
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
edge_index = torch.tensor(
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]],
dtype=torch.int64,
)
graph = Graph(x=x, edge_index=edge_index)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.x, torch.Tensor)
@pytest.mark.parametrize(
"pos",
"x, pos",
[
([torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 3)),
(torch.rand(3, 10, 3))
]
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_single_graph_single_val(pos):
x = torch.rand(10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
x = torch.rand(10, 2)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
def test_additional_parameters_1():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
additional_parameters = {'y': torch.ones(3)}
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(d_.y == 1 for d_ in data)
def test_build_radius_graph(x, pos):
graph = RadiusGraph(x=x, pos=pos, radius=0.5)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
@pytest.mark.parametrize(
"additional_parameters",
"x, pos",
[
({'y': torch.rand(3, 10, 1)}),
({'y': [torch.rand(10, 1) for _ in range(3)]}),
]
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_additional_parameters_2(additional_parameters):
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
def test_build_radius_graph_edge_attr(x, pos):
graph = RadiusGraph(x=x, pos=pos, radius=0.5, edge_attr=True)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert hasattr(graph, "edge_attr")
assert isinstance(graph.edge_attr, torch.Tensor)
assert graph.edge_attr.shape[-1] == 3
assert graph.edge_attr.shape[0] == graph.edge_index.shape[1]
def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
def build_edge_attr(x, pos, edge_index):
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)
@pytest.mark.parametrize(
"x, pos",
[
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_radius_graph_custom_edge_attr(x, pos):
graph = RadiusGraph(
x=x,
pos=pos,
radius=0.5,
edge_attr=True,
custom_edge_func=build_edge_attr,
)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert hasattr(graph, "edge_attr")
assert isinstance(graph.edge_attr, torch.Tensor)
assert graph.edge_attr.shape[-1] == 6
assert graph.edge_attr.shape[0] == graph.edge_index.shape[1]
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
custom_build_edge_attr=build_edge_attr)
assert len(graph.data) == 3
data = graph.data
assert all(hasattr(d, 'edge_attr') for d in data)
assert all(d.edge_attr.shape[1] == 4 for d in data)
assert all(torch.isclose(d.edge_attr,
build_edge_attr(d.x, d.pos, d.edge_index)).all()
for d in data)
@pytest.mark.parametrize(
"x, pos",
[
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_knn_graph(x, pos):
graph = KNNGraph(x=x, pos=pos, neighbours=2)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert graph.edge_attr is None
@pytest.mark.parametrize(
"x, pos",
[
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_knn_graph_edge_attr(x, pos):
graph = KNNGraph(x=x, pos=pos, neighbours=2, edge_attr=True)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert isinstance(graph.edge_attr, torch.Tensor)
assert graph.edge_attr.shape[-1] == 3
assert graph.edge_attr.shape[0] == graph.edge_index.shape[1]
@pytest.mark.parametrize(
"x, pos",
[
(torch.rand(10, 2), torch.rand(10, 3)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
),
],
)
def test_build_knn_graph_custom_edge_attr(x, pos):
graph = KNNGraph(
x=x,
pos=pos,
neighbours=2,
edge_attr=True,
custom_edge_func=build_edge_attr,
)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
assert torch.isclose(graph.x, x).all()
if isinstance(x, LabelTensor):
assert isinstance(graph.x, LabelTensor)
assert graph.x.labels == x.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert torch.isclose(graph.pos, pos).all()
if isinstance(pos, LabelTensor):
assert isinstance(graph.pos, LabelTensor)
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
assert isinstance(graph.edge_attr, torch.Tensor)
assert graph.edge_attr.shape[-1] == 6
assert graph.edge_attr.shape[0] == graph.edge_index.shape[1]
@pytest.mark.parametrize(
"x, pos, y",
[
(torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]),
),
],
)
def test_additional_params(x, pos, y):
edge_index = torch.tensor(
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]],
dtype=torch.int64,
)
graph = Graph(x=x, pos=pos, edge_index=edge_index, y=y)
assert hasattr(graph, "y")
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)
@pytest.mark.parametrize(
"x, pos, y",
[
(torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]),
),
],
)
def test_additional_params_radius_graph(x, pos, y):
graph = RadiusGraph(x=x, pos=pos, radius=0.5, y=y)
assert hasattr(graph, "y")
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)
@pytest.mark.parametrize(
"x, pos, y",
[
(torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)),
(
LabelTensor(torch.rand(10, 2), ["u", "v"]),
LabelTensor(torch.rand(10, 3), ["x", "y", "z"]),
LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]),
),
],
)
def test_additional_params_knn_graph(x, pos, y):
graph = KNNGraph(x=x, pos=pos, neighbours=3, y=y)
assert hasattr(graph, "y")
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)
assert torch.isclose(graph.y, y).all()
if isinstance(y, LabelTensor):
assert isinstance(graph.y, LabelTensor)
assert graph.y.labels == y.labels
else:
assert isinstance(graph.y, torch.Tensor)