Improve efficiency and refact LabelTensor, codacy correction and fix bug in PinaBatch

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent ccc5f5a322
commit ea3d1924e7
13 changed files with 496 additions and 395 deletions

View File

@@ -10,13 +10,15 @@ class Batch:
optimization.
"""
def __init__(self, dataset_dict, idx_dict):
def __init__(self, dataset_dict, idx_dict, require_grad=True):
self.attributes = []
for k, v in dataset_dict.items():
setattr(self, k, v)
self.attributes.append(k)
for k, v in idx_dict.items():
setattr(self, k + '_idx', v)
self.require_grad = require_grad
def __len__(self):
"""
@@ -31,9 +33,18 @@ class Batch:
length += len(getattr(self, dataset))
return length
def __getattribute__(self, item):
if item in super().__getattribute__('attributes'):
dataset = super().__getattribute__(item)
index = super().__getattribute__(item + '_idx')
return PinaSubset(
dataset.dataset,
dataset.indices[index])
else:
return super().__getattribute__(item)
def __getattr__(self, item):
if not item in dir(self):
raise AttributeError(f'Batch instance has no attribute {item}')
return PinaSubset(
getattr(self, item).dataset,
getattr(self, item).indices[self.coordinates_dict[item]])
if item == 'data' and len(self.attributes) == 1:
item = self.attributes[0]
return super().__getattribute__(item)
raise AttributeError(f"'Batch' object has no attribute '{item}'")