fix tensor getitem in graph_dataset (#633)
This commit is contained in:
@@ -276,20 +276,6 @@ class PinaGraphDataset(PinaDataset):
|
||||
batch = LabelBatch.from_data_list(data)
|
||||
return batch
|
||||
|
||||
def _create_tensor_batch(self, data):
|
||||
"""
|
||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
||||
based models.
|
||||
|
||||
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
|
||||
the number of data objects.
|
||||
:type data: torch.Tensor | LabelTensor
|
||||
:return: Reshaped tensor object.
|
||||
:rtype: torch.Tensor | LabelTensor
|
||||
"""
|
||||
out = data.reshape(-1, *data.shape[2:])
|
||||
return out
|
||||
|
||||
def create_batch(self, data):
|
||||
"""
|
||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||
@@ -324,7 +310,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
k: (
|
||||
self._create_graph_batch([v[i] for i in idx_list])
|
||||
if isinstance(v, list)
|
||||
else self._create_tensor_batch(v[idx_list])
|
||||
else v[idx_list]
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user