Files
PINA/pina/data/pina_dataloader.py
2025-03-19 17:46:33 +01:00

69 lines
2.2 KiB
Python

"""
This module is used to create an iterable object used during training
"""
import math
from .pina_batch import Batch
class PinaDataLoader:
    """
    This class is used to create a dataloader to use during the training.

    The elements of every condition are split into contiguous chunks, so that
    each batch holds a proportional share of every condition.

    :var condition_names: The names of the conditions. The order is consistent
        with the condition indices in the batches.
    :vartype condition_names: list[str]
    """

    def __init__(self, dataset_dict, batch_size, condition_names) -> None:
        """
        Initialize local variables and build the batches.

        :param dataset_dict: Dictionary of datasets. Its keys are reused as the
            keys of the per-batch slice dictionaries.
        :type dataset_dict: dict
        :param batch_size: Requested number of elements per batch. If ``None``,
            all the elements are put in a single batch.
        :type batch_size: int | None
        :param condition_names: Names of the conditions
        :type condition_names: list[str]
        :raises ValueError: If ``batch_size`` is not positive.
        """
        self.condition_names = condition_names
        self.dataset_dict = dataset_dict
        self._init_batches(batch_size)

    def _init_batches(self, batch_size=None):
        """
        Create batches according to the batch_size provided in input.

        :param batch_size: Requested number of elements per batch. If ``None``,
            a single batch containing every element is created.
        :type batch_size: int | None
        :raises ValueError: If ``batch_size`` is not positive.
        """
        lengths = {
            name: len(dataset) for name, dataset in self.dataset_dict.items()
        }
        self.batches = [
            Batch(idx_dict=idx_dict, dataset_dict=self.dataset_dict)
            for idx_dict in self._split_indices(lengths, batch_size)
        ]

    @staticmethod
    def _split_indices(lengths, batch_size=None):
        """
        Split every condition into contiguous slices, one slice per batch.

        The number of batches is ``ceil(total / batch_size)``. Each condition
        with ``n`` elements is cut into chunks of ``ceil(n / n_batches)``
        elements, and the last chunk holds the remainder. Every element of
        every condition ends up in exactly one batch. Batches that would
        contain no element at all are dropped.

        :param lengths: Number of elements of each condition.
        :type lengths: dict
        :param batch_size: Requested number of elements per batch, or ``None``
            to put everything in a single batch.
        :type batch_size: int | None
        :raises ValueError: If ``batch_size`` is not positive.
        :return: One dictionary per batch, mapping each condition to the slice
            of its elements contained in that batch.
        :rtype: list[dict]
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        n_elements = sum(lengths.values())
        if n_elements == 0:
            # Nothing to split. This also avoids a 0 / 0 division when
            # batch_size is None.
            return []
        if batch_size is None:
            batch_size = n_elements
        n_batches = math.ceil(n_elements / batch_size)
        # Dividing by n_batches (not n_batches - 1) keeps the chunks of a
        # single condition no larger than batch_size and never leaves an
        # empty trailing batch. Because n_batches * chunk >= n, the last batch
        # always reaches the end of each condition.
        chunks = {name: math.ceil(n / n_batches) for name, n in lengths.items()}
        batches = []
        for i in range(n_batches):
            # Clip both bounds so every slice satisfies start <= stop <= n.
            idx_dict = {
                name: slice(
                    min(i * chunks[name], n), min((i + 1) * chunks[name], n)
                )
                for name, n in lengths.items()
            }
            if any(s.stop > s.start for s in idx_dict.values()):
                batches.append(idx_dict)
        return batches

    def __iter__(self):
        """
        Makes dataloader object iterable, yielding the batches in order.
        """
        yield from self.batches

    def __len__(self):
        """
        Return the number of batches.

        :return: The number of batches.
        :rtype: int
        """
        return len(self.batches)