Improve efficiency and refact LabelTensor, codacy correction and fix bug in PinaBatch
This commit is contained in:
committed by
Nicola Demo
parent
ccc5f5a322
commit
ea3d1924e7
@@ -10,13 +10,15 @@ class Batch:
|
||||
optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_dict, idx_dict):
|
||||
|
||||
def __init__(self, dataset_dict, idx_dict, require_grad=True):
|
||||
self.attributes = []
|
||||
for k, v in dataset_dict.items():
|
||||
setattr(self, k, v)
|
||||
self.attributes.append(k)
|
||||
|
||||
for k, v in idx_dict.items():
|
||||
setattr(self, k + '_idx', v)
|
||||
self.require_grad = require_grad
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
@@ -31,9 +33,18 @@ class Batch:
|
||||
length += len(getattr(self, dataset))
|
||||
return length
|
||||
|
||||
def __getattribute__(self, item):
|
||||
if item in super().__getattribute__('attributes'):
|
||||
dataset = super().__getattribute__(item)
|
||||
index = super().__getattribute__(item + '_idx')
|
||||
return PinaSubset(
|
||||
dataset.dataset,
|
||||
dataset.indices[index])
|
||||
else:
|
||||
return super().__getattribute__(item)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if not item in dir(self):
|
||||
raise AttributeError(f'Batch instance has no attribute {item}')
|
||||
return PinaSubset(
|
||||
getattr(self, item).dataset,
|
||||
getattr(self, item).indices[self.coordinates_dict[item]])
|
||||
if item == 'data' and len(self.attributes) == 1:
|
||||
item = self.attributes[0]
|
||||
return super().__getattribute__(item)
|
||||
raise AttributeError(f"'Batch' object has no attribute '{item}'")
|
||||
Reference in New Issue
Block a user