add stack conditions option
This commit is contained in:
53
pina/data/stacked_dataloader.py
Normal file
53
pina/data/stacked_dataloader.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from math import ceil
|
||||
|
||||
|
||||
class StackedDataLoader:
|
||||
def __init__(self, datasets, batch_size=32, shuffle=True):
|
||||
for d in datasets.values():
|
||||
if d.is_graph_dataset:
|
||||
raise ValueError("Each dataset must be a dictionary")
|
||||
self.chunks = {}
|
||||
self.total_length = 0
|
||||
self.indices = []
|
||||
|
||||
self._init_chunks(datasets)
|
||||
self.indices = list(range(self.total_length))
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
if self.shuffle:
|
||||
torch.random.manual_seed(42)
|
||||
self.indices = torch.randperm(self.total_length).tolist()
|
||||
self.datasets = datasets
|
||||
|
||||
def _init_chunks(self, datasets):
|
||||
inc = 0
|
||||
total_length = 0
|
||||
for name, dataset in datasets.items():
|
||||
self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
|
||||
inc += len(dataset)
|
||||
self.total_length = inc
|
||||
|
||||
def __len__(self):
|
||||
return ceil(self.total_length / self.batch_size)
|
||||
|
||||
def _build_batch_indices(self, batch_idx):
|
||||
start = batch_idx * self.batch_size
|
||||
end = min(start + self.batch_size, self.total_length)
|
||||
return self.indices[start:end]
|
||||
|
||||
def __iter__(self):
|
||||
for batch_idx in range(len(self)):
|
||||
batch_indices = self._build_batch_indices(batch_idx)
|
||||
batch_data = {}
|
||||
for name, chunk in self.chunks.items():
|
||||
local_indices = [
|
||||
idx - chunk["start"]
|
||||
for idx in batch_indices
|
||||
if chunk["start"] <= idx < chunk["end"]
|
||||
]
|
||||
if local_indices:
|
||||
batch_data[name] = self.datasets[name].getitem_from_list(
|
||||
local_indices
|
||||
)
|
||||
yield batch_data
|
||||
Reference in New Issue
Block a user