Bug fix in Graph class

This commit is contained in:
FilippoOlivo
2025-02-09 13:01:58 +01:00
committed by Nicola Demo
parent 86405ef597
commit 7702427e8d
2 changed files with 4 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph',
"RadiusGraph", "KNNGraph"
]
from .meta import *
@@ -14,4 +15,4 @@ from .data import PinaDataModule
from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
from .graph import Graph, RadiusGraph, KNNGraph

View File

@@ -223,8 +223,8 @@ class Graph:
return [edge_attr] * data_len
if build_edge_attr:
return [self._build_edge_attr(x, pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)]
return [self._build_edge_attr(x_, pos_, edge_index_) for
x_, pos_, edge_index_ in zip(x, pos, edge_index)]
class RadiusGraph(Graph):