Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor):
return selected_lt
def _getitem_permutation(self, index, selected_lt):
new_labels = deepcopy(self.full_labels)
new_labels.update(self._update_label_for_dim(self.full_labels, index,
0))
@@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor):
:param dim:
:return:
"""
if isinstance(index, torch.Tensor):
index = index.nonzero()
if isinstance(index, list):
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
'name': old_labels[dim]['name']}}
@@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor):
return {dim: {'dof': old_labels[dim]['dof'][index],
'name': old_labels[dim]['name']}}
def sort_labels(self, dim=None):
def argsort(lst):
return sorted(range(len(lst)), key=lambda x: lst[x])