Fix bug in Collector with Graph data (#456)

* Fix bug in Collector with Graph data
* Add comments in DataModule class and bug fix in collate
This commit is contained in:
Filippo Olivo
2025-02-20 13:49:01 +01:00
committed by Nicola Demo
parent dfd6d7b467
commit 9c9d4fe7e4
6 changed files with 254 additions and 66 deletions

View File

@@ -108,16 +108,14 @@ class Graph:
x)
# Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
self._build_graph_list(
x, pos, edge_index, edge_attr, additional_params)
def _build_graph_list(self, x, pos, edge_index, edge_attr,
additional_params):
for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
if isinstance(x_, LabelTensor):
x_ = x_.tensor
add_params_local = {k: v[i] for k, v in additional_params.items()}
if edge_attr is not None:
self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
edge_attr=edge_attr[i],
**add_params_local))
@@ -127,7 +125,8 @@ class Graph:
@staticmethod
def _build_edge_attr(x, pos, edge_index):
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
distance = torch.abs(pos[edge_index[0]] -
pos[edge_index[1]]).as_subclass(torch.Tensor)
return distance
@staticmethod
@@ -165,7 +164,8 @@ class Graph:
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
if edge_index is not None:
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
edge_index = [edge_index[i]
for i in range(edge_index.shape[0])]
elif not (isinstance(edge_index, list) and all(
t.ndim == 2 for t in edge_index)) and not (
isinstance(edge_index,
@@ -219,7 +219,7 @@ class Graph:
if isinstance(edge_attr, list):
if len(edge_attr) != data_len:
raise TypeError("edge_attr must have the same length as x "
"and pos.")
"and pos.")
return [edge_attr] * data_len
if build_edge_attr:
@@ -258,6 +258,8 @@ class RadiusGraph(Graph):
"""
dist = torch.cdist(points, points, p=2)
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
if isinstance(edge_index, LabelTensor):
edge_index = edge_index.tensor
return edge_index
@@ -293,4 +295,6 @@ class KNNGraph(Graph):
row = torch.arange(points.size(0)).repeat_interleave(k)
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
if isinstance(edge_index, LabelTensor):
edge_index = edge_index.tensor
return edge_index