Files
PINA/pina/data/pina_subset.py

37 lines
969 B
Python

"""
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor, float32
class PinaSubset:
    """
    View over a subset of a dataset, selected by a collection of indices.

    Attribute lookups that cannot be resolved on the subset itself are
    forwarded to the wrapped dataset; tensor and list attributes are
    returned restricted to ``indices``.
    """
    __slots__ = ['dataset', 'indices', 'require_grad']

    def __init__(self, dataset, indices, require_grad=True):
        """
        Store the wrapped dataset and the selection.

        :param dataset: Object whose attributes are exposed through the
            subset. Its ``device`` attribute is read whenever a tensor
            attribute is requested.
        :param indices: Positions used to index the dataset's tensor and
            list attributes; its length is the length of the subset.
        :param bool require_grad: Value passed to ``requires_grad_`` on
            the ``float32`` tensors returned by attribute access.
        """
        self.dataset = dataset
        self.indices = indices
        self.require_grad = require_grad

    def __len__(self):
        """
        Return the number of selected indices.
        """
        return len(self.indices)

    def __getattr__(self, name):
        """
        Return the dataset attribute ``name`` restricted to ``indices``.

        Python calls this method only when the regular lookup on the
        subset fails.

        :param str name: Name of the attribute to fetch from the dataset.
        :return: For a tensor attribute, the indexed tensor moved to
            ``dataset.device`` (with ``requires_grad_(require_grad)``
            applied when its dtype is ``float32``); for a list attribute,
            a new list holding the elements at ``indices``.
        :raises AttributeError: If ``name`` is one of the subset's own
            slots that has not been assigned yet, if the dataset has no
            such attribute, or if the attribute is neither a tensor nor
            a list.
        """
        # An unset slot ends up here. Reading ``self.dataset`` below would
        # then re-enter this method without end (``copy`` and ``pickle``
        # probe freshly created, still empty instances with ``hasattr``),
        # so fail fast with the exception ``hasattr`` expects.
        if name in self.__slots__:
            raise AttributeError(f"No attribute named {name}")

        dataset = self.dataset
        # Direct ``__getattribute__`` call kept on purpose: it preserves the
        # original lookup semantics on the wrapped dataset.
        value = dataset.__getattribute__(name)

        if isinstance(value, list):
            return [value[i] for i in self.indices]

        if isinstance(value, (LabelTensor, Tensor)):
            selected = value[[self.indices]].to(dataset.device)
            # Only float32 tensors get their gradient flag updated; tensors
            # of any other dtype are returned as they are.
            if selected.dtype == float32:
                return selected.requires_grad_(self.require_grad)
            return selected

        raise AttributeError(f"No attribute named {name}")