Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -3,7 +3,7 @@ This module provides an interface to build torch_geometric.data.Data objects.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.data import Data, Batch
|
||||
from torch_geometric.utils import to_undirected
|
||||
from . import LabelTensor
|
||||
from .utils import check_consistency, is_function
|
||||
@@ -162,6 +162,21 @@ class Graph(Data):
|
||||
edge_index = to_undirected(edge_index)
|
||||
return edge_index
|
||||
|
||||
def extract(self, labels, attr="x"):
|
||||
"""
|
||||
Perform extraction of labels on node features (x)
|
||||
|
||||
:param labels: Labels to extract
|
||||
:type labels: list[str] | tuple[str] | str
|
||||
:return: Batch object with extraction performed on x
|
||||
:rtype: PinaBatch
|
||||
"""
|
||||
# Extract labels from LabelTensor object
|
||||
tensor = getattr(self, attr).extract(labels)
|
||||
# Set the extracted tensor as the new attribute
|
||||
setattr(self, attr, tensor)
|
||||
return self
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
"""
|
||||
@@ -317,3 +332,31 @@ class KNNGraph(GraphBuilder):
|
||||
row = torch.arange(points.size(0)).repeat_interleave(k)
|
||||
col = knn_indices.flatten()
|
||||
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
||||
|
||||
|
||||
class LabelBatch(Batch):
|
||||
"""
|
||||
Add extract function to torch_geometric Batch object
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_data_list(cls, data_list):
|
||||
"""
|
||||
Create a Batch object from a list of Data objects.
|
||||
"""
|
||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||
# If the data do not contain labels, labels is an empty dictionary,
|
||||
# therefore the labels are not stored
|
||||
labels = {
|
||||
k: v.labels
|
||||
for k, v in data_list[0].items()
|
||||
if isinstance(v, LabelTensor)
|
||||
}
|
||||
|
||||
# Create a Batch object from the list of Data objects
|
||||
batch = super().from_data_list(data_list)
|
||||
|
||||
# Put the labels back in the Batch object
|
||||
for k, v in labels.items():
|
||||
batch[k].labels = v
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user