diff --git a/pina/__init__.py b/pina/__init__.py index e9ce706..e69db88 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,6 +1,7 @@ __all__ = [ "Trainer", "LabelTensor", "Plotter", "Condition", "PinaDataModule", 'TorchOptimizer', 'Graph', + "RadiusGraph", "KNNGraph" ] from .meta import * @@ -14,4 +15,4 @@ from .data import PinaDataModule from .optim import TorchOptimizer from .optim import TorchScheduler -from .graph import Graph +from .graph import Graph, RadiusGraph, KNNGraph diff --git a/pina/graph.py b/pina/graph.py index 8c16741..959bd9c 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -223,8 +223,8 @@ class Graph: return [edge_attr] * data_len if build_edge_attr: - return [self._build_edge_attr(x, pos_, edge_index_) for - pos_, edge_index_ in zip(pos, edge_index)] + return [self._build_edge_attr(x_, pos_, edge_index_) for + x_, pos_, edge_index_ in zip(x, pos, edge_index)] class RadiusGraph(Graph):