Fix bug in Collector with Graph data (#456)
* Fix bug in Collector with Graph data * Add comments in DataModule class and bug fix in collate
This commit is contained in:
committed by
Nicola Demo
parent
dfd6d7b467
commit
9c9d4fe7e4
@@ -108,16 +108,14 @@ class Graph:
|
||||
x)
|
||||
|
||||
# Perform the graph construction
|
||||
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
||||
self._build_graph_list(
|
||||
x, pos, edge_index, edge_attr, additional_params)
|
||||
|
||||
def _build_graph_list(self, x, pos, edge_index, edge_attr,
|
||||
additional_params):
|
||||
for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
|
||||
if isinstance(x_, LabelTensor):
|
||||
x_ = x_.tensor
|
||||
add_params_local = {k: v[i] for k, v in additional_params.items()}
|
||||
if edge_attr is not None:
|
||||
|
||||
self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
|
||||
edge_attr=edge_attr[i],
|
||||
**add_params_local))
|
||||
@@ -127,7 +125,8 @@ class Graph:
|
||||
|
||||
@staticmethod
|
||||
def _build_edge_attr(x, pos, edge_index):
|
||||
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
|
||||
distance = torch.abs(pos[edge_index[0]] -
|
||||
pos[edge_index[1]]).as_subclass(torch.Tensor)
|
||||
return distance
|
||||
|
||||
@staticmethod
|
||||
@@ -165,7 +164,8 @@ class Graph:
|
||||
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
|
||||
if edge_index is not None:
|
||||
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||
edge_index = [edge_index[i]
|
||||
for i in range(edge_index.shape[0])]
|
||||
elif not (isinstance(edge_index, list) and all(
|
||||
t.ndim == 2 for t in edge_index)) and not (
|
||||
isinstance(edge_index,
|
||||
@@ -219,7 +219,7 @@ class Graph:
|
||||
if isinstance(edge_attr, list):
|
||||
if len(edge_attr) != data_len:
|
||||
raise TypeError("edge_attr must have the same length as x "
|
||||
"and pos.")
|
||||
"and pos.")
|
||||
return [edge_attr] * data_len
|
||||
|
||||
if build_edge_attr:
|
||||
@@ -258,6 +258,8 @@ class RadiusGraph(Graph):
|
||||
"""
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
|
||||
if isinstance(edge_index, LabelTensor):
|
||||
edge_index = edge_index.tensor
|
||||
return edge_index
|
||||
|
||||
|
||||
@@ -293,4 +295,6 @@ class KNNGraph(Graph):
|
||||
row = torch.arange(points.size(0)).repeat_interleave(k)
|
||||
col = knn_indices.flatten()
|
||||
edge_index = torch.stack([row, col], dim=0)
|
||||
if isinstance(edge_index, LabelTensor):
|
||||
edge_index = edge_index.tensor
|
||||
return edge_index
|
||||
|
||||
Reference in New Issue
Block a user