fix tensor getitem in graph_dataset (#633)
This commit is contained in:
@@ -276,20 +276,6 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
batch = LabelBatch.from_data_list(data)
|
batch = LabelBatch.from_data_list(data)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _create_tensor_batch(self, data):
|
|
||||||
"""
|
|
||||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
|
||||||
based models.
|
|
||||||
|
|
||||||
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
|
|
||||||
the number of data objects.
|
|
||||||
:type data: torch.Tensor | LabelTensor
|
|
||||||
:return: Reshaped tensor object.
|
|
||||||
:rtype: torch.Tensor | LabelTensor
|
|
||||||
"""
|
|
||||||
out = data.reshape(-1, *data.shape[2:])
|
|
||||||
return out
|
|
||||||
|
|
||||||
def create_batch(self, data):
|
def create_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||||
@@ -324,7 +310,7 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
k: (
|
k: (
|
||||||
self._create_graph_batch([v[i] for i in idx_list])
|
self._create_graph_batch([v[i] for i in idx_list])
|
||||||
if isinstance(v, list)
|
if isinstance(v, list)
|
||||||
else self._create_tensor_batch(v[idx_list])
|
else v[idx_list]
|
||||||
)
|
)
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ def test_getitem(conditions_dict, max_conditions_lengths):
|
|||||||
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
|
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
|
||||||
)
|
)
|
||||||
assert all(
|
assert all(
|
||||||
[d["target"].shape == torch.Size((400, 10)) for d in data.values()]
|
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
|
||||||
)
|
)
|
||||||
assert all(
|
assert all(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import torch
|
|||||||
import pytest
|
import pytest
|
||||||
from torch._dynamo.eval_frame import OptimizedModule
|
from torch._dynamo.eval_frame import OptimizedModule
|
||||||
from torch_geometric.nn import GCNConv
|
from torch_geometric.nn import GCNConv
|
||||||
|
from torch_geometric.utils import to_dense_batch
|
||||||
from pina import Condition, LabelTensor
|
from pina import Condition, LabelTensor
|
||||||
from pina.condition import InputTargetCondition
|
from pina.condition import InputTargetCondition
|
||||||
from pina.problem import AbstractProblem
|
from pina.problem import AbstractProblem
|
||||||
@@ -82,7 +83,7 @@ class Models(torch.nn.Module):
|
|||||||
y = self.conv(y, edge_index)
|
y = self.conv(y, edge_index)
|
||||||
y = self.activation(y)
|
y = self.activation(y)
|
||||||
y = self.output(y)
|
y = self.output(y)
|
||||||
return y
|
return to_dense_batch(y, batch.batch)[0]
|
||||||
|
|
||||||
|
|
||||||
graph_models = [Models() for i in range(10)]
|
graph_models = [Models() for i in range(10)]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import torch
|
|||||||
import pytest
|
import pytest
|
||||||
from torch._dynamo.eval_frame import OptimizedModule
|
from torch._dynamo.eval_frame import OptimizedModule
|
||||||
from torch_geometric.nn import GCNConv
|
from torch_geometric.nn import GCNConv
|
||||||
|
from torch_geometric.utils import to_dense_batch
|
||||||
from pina import Condition, LabelTensor
|
from pina import Condition, LabelTensor
|
||||||
from pina.condition import InputTargetCondition
|
from pina.condition import InputTargetCondition
|
||||||
from pina.problem import AbstractProblem
|
from pina.problem import AbstractProblem
|
||||||
@@ -82,7 +83,7 @@ class Model(torch.nn.Module):
|
|||||||
y = self.conv(y, edge_index)
|
y = self.conv(y, edge_index)
|
||||||
y = self.activation(y)
|
y = self.activation(y)
|
||||||
y = self.output(y)
|
y = self.output(y)
|
||||||
return y
|
return to_dense_batch(y, batch.batch)[0]
|
||||||
|
|
||||||
|
|
||||||
graph_model = Model()
|
graph_model = Model()
|
||||||
|
|||||||
66
tutorials/tutorial15/tutorial.ipynb
vendored
66
tutorials/tutorial15/tutorial.ipynb
vendored
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user