Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -1,6 +1,8 @@
"""
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor
class PinaSubset:
@@ -23,4 +25,9 @@ class PinaSubset:
return len(self.indices)
def __getattr__(self, name):
return self.dataset.__getattribute__(name)
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
return tensor[self.indices]
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
raise AttributeError("No attribute named {}".format(name))