fix tensor getitem in graph_dataset (#633)

This commit is contained in:
Dario Coscia
2025-09-10 12:04:41 +02:00
committed by GitHub
parent 7469543499
commit 85b9edc74d
5 changed files with 39 additions and 51 deletions

View File

@@ -101,7 +101,7 @@ def test_getitem(conditions_dict, max_conditions_lengths):
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
)
assert all(
[d["target"].shape == torch.Size((400, 10)) for d in data.values()]
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
)
assert all(
[