Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
Module for PinaSubset class
|
||||
"""
|
||||
from pina import LabelTensor
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class PinaSubset:
|
||||
@@ -23,4 +25,9 @@ class PinaSubset:
|
||||
return len(self.indices)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self.dataset.__getattribute__(name)
|
||||
tensor = self.dataset.__getattribute__(name)
|
||||
if isinstance(tensor, (LabelTensor, Tensor)):
|
||||
return tensor[self.indices]
|
||||
if isinstance(tensor, list):
|
||||
return [tensor[i] for i in self.indices]
|
||||
raise AttributeError("No attribute named {}".format(name))
|
||||
|
||||
Reference in New Issue
Block a user