From 7702427e8d4aaeefb6886e5312bc5a1bfe941587 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Sun, 9 Feb 2025 13:01:58 +0100 Subject: [PATCH] Bug fix in Graph class --- pina/__init__.py | 3 ++- pina/graph.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pina/__init__.py b/pina/__init__.py index e9ce706..e69db88 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,6 +1,7 @@ __all__ = [ "Trainer", "LabelTensor", "Plotter", "Condition", "PinaDataModule", 'TorchOptimizer', 'Graph', + "RadiusGraph", "KNNGraph" ] from .meta import * @@ -14,4 +15,4 @@ from .data import PinaDataModule from .optim import TorchOptimizer from .optim import TorchScheduler -from .graph import Graph +from .graph import Graph, RadiusGraph, KNNGraph diff --git a/pina/graph.py b/pina/graph.py index 8c16741..959bd9c 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -223,8 +223,8 @@ class Graph: return [edge_attr] * data_len if build_edge_attr: - return [self._build_edge_attr(x, pos_, edge_index_) for - pos_, edge_index_ in zip(pos, edge_index)] + return [self._build_edge_attr(x_, pos_, edge_index_) for + x_, pos_, edge_index_ in zip(x, pos, edge_index)] class RadiusGraph(Graph):