37 lines
969 B
Python
37 lines
969 B
Python
"""
|
|
Module defining the PinaSubset class, an index-based view over a dataset.
|
|
"""
|
|
from pina import LabelTensor
|
|
from torch import Tensor, float32
|
|
|
|
|
|
class PinaSubset:
    """
    Index-based view over a subset of the samples held by a dataset.

    The subset stores no data of its own. Attribute lookups that the subset
    cannot satisfy itself are forwarded to the wrapped dataset, and tensor or
    list results are restricted to the entries selected by ``indices``.
    """

    __slots__ = ['dataset', 'indices', 'require_grad']

    def __init__(self, dataset, indices, require_grad=True):
        """
        Store the wrapped dataset and the selection applied to it.

        :param dataset: Object whose attributes are exposed through the
            subset. It must provide a ``device`` attribute whenever a tensor
            attribute is requested.
        :param indices: Sized, iterable collection of positions to select
            from each forwarded tensor or list.
        :param bool require_grad: Flag passed to ``requires_grad_`` on
            selected ``float32`` tensors. Default is ``True``.
        """
        self.dataset = dataset
        self.indices = indices
        self.require_grad = require_grad

    def __len__(self):
        """
        Return the number of selected entries.

        :return: ``len(self.indices)``.
        :rtype: int
        """
        return len(self.indices)

    def __getattr__(self, name):
        """
        Forward a failed attribute lookup to the dataset and subset the
        result.

        * list attributes yield a new list holding the elements at
          ``indices``;
        * ``LabelTensor`` / ``Tensor`` attributes are indexed with
          ``indices`` and moved to ``dataset.device``; ``float32`` results
          additionally get ``requires_grad_(self.require_grad)`` applied.

        :param str name: Name of the attribute to fetch from the dataset.
        :return: The subset of the requested dataset attribute.
        :raises AttributeError: If ``name`` is one of this class's own (still
            unset) slots, if the dataset lacks the attribute, or if the
            attribute is neither a tensor nor a list.
        """
        # __getattr__ only runs after normal lookup has failed. If the
        # missing name is one of our own slots, the instance is not yet
        # initialised (copy/pickle build it via __new__ and then probe for
        # __setstate__). Reading self.dataset below would then re-enter this
        # method endlessly and surface as RecursionError, which hasattr()
        # does not swallow.
        if name in PinaSubset.__slots__:
            raise AttributeError(f"No attribute named {name}")

        value = self.dataset.__getattribute__(name)

        if isinstance(value, list):
            return [value[i] for i in self.indices]

        if isinstance(value, (LabelTensor, Tensor)):
            # The index list is deliberately wrapped in an extra list, exactly
            # as the original code did.
            # NOTE(review): confirm LabelTensor indexing expects this form.
            selected = value[[self.indices]].to(self.dataset.device)
            # Only float32 tensors have their gradient flag touched; every
            # other dtype is returned as selected.
            if selected.dtype == float32:
                return selected.requires_grad_(self.require_grad)
            return selected

        raise AttributeError(f"No attribute named {name}")
|