Improve efficiency and refact LabelTensor, codacy correction and fix bug in PinaBatch

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent ccc5f5a322
commit ea3d1924e7
13 changed files with 496 additions and 395 deletions

View File

@@ -2,21 +2,22 @@
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor
from torch import Tensor, float32
class PinaSubset:
"""
TODO
"""
__slots__ = ['dataset', 'indices']
__slots__ = ['dataset', 'indices', 'require_grad']
def __init__(self, dataset, indices):
def __init__(self, dataset, indices, require_grad=True):
"""
TODO
"""
self.dataset = dataset
self.indices = indices
self.require_grad = require_grad
def __len__(self):
"""
@@ -27,7 +28,9 @@ class PinaSubset:
def __getattr__(self, name):
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
return tensor[self.indices]
tensor = tensor[[self.indices]].to(self.dataset.device)
return tensor.requires_grad_(
self.require_grad) if tensor.dtype == float32 else tensor
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
raise AttributeError("No attribute named {}".format(name))
raise AttributeError(f"No attribute named {name}")