Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -3,7 +3,7 @@ This module provides an interface to build torch_geometric.data.Data objects.
"""
import torch
from torch_geometric.data import Data
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected
from . import LabelTensor
from .utils import check_consistency, is_function
@@ -162,6 +162,21 @@ class Graph(Data):
edge_index = to_undirected(edge_index)
return edge_index
def extract(self, labels, attr="x"):
    """
    Restrict one attribute of this object to the requested labels.

    The attribute named ``attr`` is read, its own ``extract`` method is
    called with ``labels``, and the result is stored back under the same
    attribute name. The object is modified in place.

    :param labels: Labels to extract.
    :type labels: list[str] | tuple[str] | str
    :param str attr: Name of the attribute to extract from. The attribute
        must expose an ``extract`` method (e.g. a LabelTensor).
        Default is ``"x"`` (node features).
    :return: This same object, with the attribute replaced by the
        extracted tensor.
    """
    # Read, extract and overwrite in a single step; the original value of
    # the attribute is not kept.
    setattr(self, attr, getattr(self, attr).extract(labels))
    return self
class GraphBuilder:
"""
@@ -317,3 +332,31 @@ class KNNGraph(GraphBuilder):
row = torch.arange(points.size(0)).repeat_interleave(k)
col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
class LabelBatch(Batch):
    """
    torch_geometric Batch that re-attaches the labels of LabelTensor
    attributes after the batch has been built.
    """

    @classmethod
    def from_data_list(cls, data_list, follow_batch=None, exclude_keys=None):
        """
        Create a Batch object from a list of Data objects, keeping the
        labels of every LabelTensor attribute.

        The labels are read from the first element of ``data_list`` only;
        all elements are assumed to share the same labels.

        :param list data_list: Non-empty list of Data/Graph objects.
        :param follow_batch: Forwarded unchanged to the parent
            ``from_data_list``. Default is ``None``.
        :type follow_batch: list[str] | None
        :param exclude_keys: Forwarded unchanged to the parent
            ``from_data_list``. Labels of the keys listed here are not
            restored. Default is ``None``.
        :type exclude_keys: list[str] | None
        :return: The batch built by the parent class, with labels set on
            the attributes that were LabelTensor in the first element.
        :rtype: LabelBatch
        """
        skipped = set(exclude_keys) if exclude_keys is not None else set()

        # Store the labels of Data/Graph objects (all data have the same
        # labels). If the data do not contain LabelTensor attributes the
        # dictionary is empty and nothing is restored afterwards.
        labels = {
            key: value.labels
            for key, value in data_list[0].items()
            if isinstance(value, LabelTensor) and key not in skipped
        }

        # Build the batch with the parent implementation, passing through
        # the optional arguments so the override keeps the parent's API.
        batch = super().from_data_list(
            data_list,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Put the labels back in the Batch object
        for key, value in labels.items():
            batch[key].labels = value
        return batch