Bug fix in Graph class
This commit is contained in:
committed by
Nicola Demo
parent
86405ef597
commit
7702427e8d
@@ -1,6 +1,7 @@
|
||||
__all__ = [
|
||||
"Trainer", "LabelTensor", "Plotter", "Condition",
|
||||
"PinaDataModule", 'TorchOptimizer', 'Graph',
|
||||
"RadiusGraph", "KNNGraph"
|
||||
]
|
||||
|
||||
from .meta import *
|
||||
@@ -14,4 +15,4 @@ from .data import PinaDataModule
|
||||
|
||||
from .optim import TorchOptimizer
|
||||
from .optim import TorchScheduler
|
||||
from .graph import Graph
|
||||
from .graph import Graph, RadiusGraph, KNNGraph
|
||||
|
||||
@@ -223,8 +223,8 @@ class Graph:
|
||||
return [edge_attr] * data_len
|
||||
|
||||
if build_edge_attr:
|
||||
return [self._build_edge_attr(x, pos_, edge_index_) for
|
||||
pos_, edge_index_ in zip(pos, edge_index)]
|
||||
return [self._build_edge_attr(x_, pos_, edge_index_) for
|
||||
x_, pos_, edge_index_ in zip(x, pos, edge_index)]
|
||||
|
||||
|
||||
class RadiusGraph(Graph):
|
||||
|
||||
Reference in New Issue
Block a user