34 lines
958 B
Python
34 lines
958 B
Python
"""
|
|
Batch management module
|
|
"""
|
|
from .pina_subset import PinaSubset
|
|
|
|
|
|
class Batch:
|
|
def __init__(self, dataset_dict, idx_dict):
|
|
|
|
for k, v in dataset_dict.items():
|
|
setattr(self, k, v)
|
|
|
|
for k, v in idx_dict.items():
|
|
setattr(self, k + '_idx', v)
|
|
|
|
def __len__(self):
|
|
"""
|
|
Returns the number of elements in the batch
|
|
:return: number of elements in the batch
|
|
:rtype: int
|
|
"""
|
|
length = 0
|
|
for dataset in dir(self):
|
|
attribute = getattr(self, dataset)
|
|
if isinstance(attribute, list):
|
|
length += len(getattr(self, dataset))
|
|
return length
|
|
|
|
def __getattr__(self, item):
|
|
if not item in dir(self):
|
|
raise AttributeError(f'Batch instance has no attribute {item}')
|
|
return PinaSubset(getattr(self, item).dataset,
|
|
getattr(self, item).indices[self.coordinates_dict[item]])
|