fix tensor getitem in graph_dataset (#633)

This commit is contained in:
Dario Coscia
2025-09-10 12:04:41 +02:00
committed by GitHub
parent 7469543499
commit 85b9edc74d
5 changed files with 39 additions and 51 deletions

View File

@@ -276,20 +276,6 @@ class PinaGraphDataset(PinaDataset):
batch = LabelBatch.from_data_list(data) batch = LabelBatch.from_data_list(data)
return batch return batch
def _create_tensor_batch(self, data):
"""
Reshape properly ``data`` tensor to be processed handle by the graph
based models.
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
the number of data objects.
:type data: torch.Tensor | LabelTensor
:return: Reshaped tensor object.
:rtype: torch.Tensor | LabelTensor
"""
out = data.reshape(-1, *data.shape[2:])
return out
def create_batch(self, data): def create_batch(self, data):
""" """
Create a Batch object from a list of :class:`~torch_geometric.data.Data` Create a Batch object from a list of :class:`~torch_geometric.data.Data`
@@ -324,7 +310,7 @@ class PinaGraphDataset(PinaDataset):
k: ( k: (
self._create_graph_batch([v[i] for i in idx_list]) self._create_graph_batch([v[i] for i in idx_list])
if isinstance(v, list) if isinstance(v, list)
else self._create_tensor_batch(v[idx_list]) else v[idx_list]
) )
for k, v in data.items() for k, v in data.items()
} }

View File

@@ -101,7 +101,7 @@ def test_getitem(conditions_dict, max_conditions_lengths):
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
) )
assert all( assert all(
[d["target"].shape == torch.Size((400, 10)) for d in data.values()] [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
) )
assert all( assert all(
[ [

View File

@@ -2,6 +2,7 @@ import torch
import pytest import pytest
from torch._dynamo.eval_frame import OptimizedModule from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_batch
from pina import Condition, LabelTensor from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem from pina.problem import AbstractProblem
@@ -82,7 +83,7 @@ class Models(torch.nn.Module):
y = self.conv(y, edge_index) y = self.conv(y, edge_index)
y = self.activation(y) y = self.activation(y)
y = self.output(y) y = self.output(y)
return y return to_dense_batch(y, batch.batch)[0]
graph_models = [Models() for i in range(10)] graph_models = [Models() for i in range(10)]

View File

@@ -2,6 +2,7 @@ import torch
import pytest import pytest
from torch._dynamo.eval_frame import OptimizedModule from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_batch
from pina import Condition, LabelTensor from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem from pina.problem import AbstractProblem
@@ -82,7 +83,7 @@ class Model(torch.nn.Module):
y = self.conv(y, edge_index) y = self.conv(y, edge_index)
y = self.activation(y) y = self.activation(y)
y = self.output(y) y = self.output(y)
return y return to_dense_batch(y, batch.batch)[0]
graph_model = Model() graph_model = Model()

File diff suppressed because one or more lines are too long