fix tensor getitem in graph_dataset (#633)
This commit is contained in:
@@ -101,7 +101,7 @@ def test_getitem(conditions_dict, max_conditions_lengths):
|
||||
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[d["target"].shape == torch.Size((400, 10)) for d in data.values()]
|
||||
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import pytest
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.utils import to_dense_batch
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.problem import AbstractProblem
|
||||
@@ -82,7 +83,7 @@ class Models(torch.nn.Module):
|
||||
y = self.conv(y, edge_index)
|
||||
y = self.activation(y)
|
||||
y = self.output(y)
|
||||
return y
|
||||
return to_dense_batch(y, batch.batch)[0]
|
||||
|
||||
|
||||
graph_models = [Models() for i in range(10)]
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import pytest
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.utils import to_dense_batch
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.problem import AbstractProblem
|
||||
@@ -82,7 +83,7 @@ class Model(torch.nn.Module):
|
||||
y = self.conv(y, edge_index)
|
||||
y = self.activation(y)
|
||||
y = self.output(y)
|
||||
return y
|
||||
return to_dense_batch(y, batch.batch)[0]
|
||||
|
||||
|
||||
graph_model = Model()
|
||||
|
||||
Reference in New Issue
Block a user