Files
PINA/pina/data/pina_batch.py
2025-03-19 17:46:34 +01:00

48 lines
1.5 KiB
Python

"""
Batch management module
"""
from .pina_subset import PinaSubset
class Batch:
    """
    Implementation of the Batch class used during training to perform SGD
    optimization.

    Each entry of ``dataset_dict`` is stored as an attribute of the batch
    under its own name, and each entry of ``idx_dict`` is stored as
    ``<name>_idx``. Reading a dataset attribute returns a
    :class:`PinaSubset` restricted to the positions selected for this batch.
    """

    def __init__(self, dataset_dict, idx_dict, require_grad=True):
        """
        Initialize the batch.

        :param dict dataset_dict: Mapping from dataset name to a subset-like
            object exposing ``dataset`` and ``indices`` attributes.
        :param dict idx_dict: Mapping from dataset name to the positions,
            within that object's ``indices``, that belong to this batch.
        :param bool require_grad: Value stored as ``require_grad``; it is not
            read by this class itself.
        """
        self.attributes = []
        for name, dataset in dataset_dict.items():
            setattr(self, name, dataset)
            self.attributes.append(name)
        for name, index in idx_dict.items():
            setattr(self, name + '_idx', index)
        self.require_grad = require_grad

    def __len__(self):
        """
        Returns the number of elements in the batch, i.e. the total number of
        indices selected across all datasets. Datasets without an associated
        ``<name>_idx`` entry contribute nothing.

        :return: number of elements in the batch
        :rtype: int
        """
        # Read the raw index containers directly. Scanning ``dir(self)`` would
        # also count the ``attributes`` name list itself, and would build a
        # throwaway PinaSubset for every dataset attribute.
        # NOTE(review): assumes index containers are sized (list / tensor).
        length = 0
        for name in self.attributes:
            index = getattr(self, name + '_idx', None)
            if index is not None:
                length += len(index)
        return length

    def __getattribute__(self, item):
        """
        Look up ``item``; dataset attributes are returned restricted to the
        indices selected for this batch.

        :param str item: Name of the attribute to look up.
        :return: A :class:`PinaSubset` for dataset names, otherwise the
            regular attribute value.
        """
        try:
            attributes = super().__getattribute__('attributes')
        except AttributeError:
            # ``attributes`` is missing on instances built without __init__
            # (e.g. by copy or pickle); keep plain lookups such as __dict__
            # working so those instances can be restored.
            attributes = ()
        if item in attributes:
            dataset = super().__getattribute__(item)
            index = super().__getattribute__(item + '_idx')
            return PinaSubset(dataset.dataset, dataset.indices[index])
        return super().__getattribute__(item)

    def __getattr__(self, item):
        """
        Fallback lookup, called only when regular attribute access fails.

        When the batch holds exactly one dataset, ``data`` is an alias for it
        and returns the same batch-restricted subset as accessing it by name.

        :param str item: Name of the missing attribute.
        :raises AttributeError: If ``item`` cannot be resolved.
        """
        attributes = self.__dict__.get('attributes', ())
        if item == 'data' and len(attributes) == 1:
            # Go through __getattribute__ so the batch indices are applied.
            return getattr(self, attributes[0])
        raise AttributeError(f"'Batch' object has no attribute '{item}'")