Improve efficiency and refact LabelTensor, codacy correction and fix bug in PinaBatch
This commit is contained in:
committed by
Nicola Demo
parent
ccc5f5a322
commit
ea3d1924e7
@@ -2,21 +2,22 @@
|
||||
Module for PinaSubset class
|
||||
"""
|
||||
from pina import LabelTensor
|
||||
from torch import Tensor
|
||||
from torch import Tensor, float32
|
||||
|
||||
|
||||
class PinaSubset:
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
__slots__ = ['dataset', 'indices']
|
||||
__slots__ = ['dataset', 'indices', 'require_grad']
|
||||
|
||||
def __init__(self, dataset, indices):
|
||||
def __init__(self, dataset, indices, require_grad=True):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.indices = indices
|
||||
self.require_grad = require_grad
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
@@ -27,7 +28,9 @@ class PinaSubset:
|
||||
def __getattr__(self, name):
|
||||
tensor = self.dataset.__getattribute__(name)
|
||||
if isinstance(tensor, (LabelTensor, Tensor)):
|
||||
return tensor[self.indices]
|
||||
tensor = tensor[[self.indices]].to(self.dataset.device)
|
||||
return tensor.requires_grad_(
|
||||
self.require_grad) if tensor.dtype == float32 else tensor
|
||||
if isinstance(tensor, list):
|
||||
return [tensor[i] for i in self.indices]
|
||||
raise AttributeError("No attribute named {}".format(name))
|
||||
raise AttributeError(f"No attribute named {name}")
|
||||
|
||||
Reference in New Issue
Block a user