Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor):
|
||||
return selected_lt
|
||||
|
||||
def _getitem_permutation(self, index, selected_lt):
|
||||
|
||||
new_labels = deepcopy(self.full_labels)
|
||||
new_labels.update(self._update_label_for_dim(self.full_labels, index,
|
||||
0))
|
||||
@@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor):
|
||||
:param dim:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(index, torch.Tensor):
|
||||
index = index.nonzero()
|
||||
if isinstance(index, list):
|
||||
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
|
||||
'name': old_labels[dim]['name']}}
|
||||
@@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor):
|
||||
return {dim: {'dof': old_labels[dim]['dof'][index],
|
||||
'name': old_labels[dim]['name']}}
|
||||
|
||||
|
||||
def sort_labels(self, dim=None):
|
||||
def argsort(lst):
|
||||
return sorted(range(len(lst)), key=lambda x: lst[x])
|
||||
|
||||
Reference in New Issue
Block a user