fix tensor getitem in graph_dataset (#633)

This commit is contained in:
Dario Coscia
2025-09-10 12:04:41 +02:00
committed by GitHub
parent 7469543499
commit 85b9edc74d
5 changed files with 39 additions and 51 deletions

View File

@@ -2,6 +2,7 @@ import torch
import pytest
from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_batch
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
@@ -82,7 +83,7 @@ class Model(torch.nn.Module):
y = self.conv(y, edge_index)
y = self.activation(y)
y = self.output(y)
return y
return to_dense_batch(y, batch.batch)[0]
graph_model = Model()