Bug fix in Graph class

This commit is contained in:
FilippoOlivo
2025-02-09 13:01:58 +01:00
committed by Nicola Demo
parent 86405ef597
commit 7702427e8d
2 changed files with 4 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
__all__ = [ __all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph', "PinaDataModule", 'TorchOptimizer', 'Graph',
"RadiusGraph", "KNNGraph"
] ]
from .meta import * from .meta import *
@@ -14,4 +15,4 @@ from .data import PinaDataModule
from .optim import TorchOptimizer from .optim import TorchOptimizer
from .optim import TorchScheduler from .optim import TorchScheduler
from .graph import Graph from .graph import Graph, RadiusGraph, KNNGraph

View File

@@ -223,8 +223,8 @@ class Graph:
return [edge_attr] * data_len return [edge_attr] * data_len
if build_edge_attr: if build_edge_attr:
return [self._build_edge_attr(x, pos_, edge_index_) for return [self._build_edge_attr(x_, pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)] x_, pos_, edge_index_ in zip(x, pos, edge_index)]
class RadiusGraph(Graph): class RadiusGraph(Graph):