add stack conditions option
@@ -1,11 +1,13 @@
 """DataLoader module for PinaDataset."""
 
 import itertools
+import random
 from functools import partial
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler
+from .stacked_dataloader import StackedDataLoader
 
 
 class DummyDataloader:
@@ -152,6 +154,22 @@ class PinaDataLoader:
     Custom DataLoader for PinaDataset.
     """
 
+    def __new__(cls, *args, **kwargs):
+        batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
+        batch_size = kwargs.get("batch_size")
+        if batching_mode == "stacked" and batch_size is not None:
+            return StackedDataLoader(
+                args[0],
+                batch_size=batch_size,
+                shuffle=kwargs.get("shuffle", True),
+            )
+        elif batch_size is None:
+            kwargs["batching_mode"] = "proportional"
+            print(
+                "Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
+            )
+        return super(PinaDataLoader, cls).__new__(cls)
+
     def __init__(
         self,
         dataset_dict,
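A note on the `__new__` dispatch above: when `__new__` returns an object that is not an instance of `cls`, Python skips the class's `__init__` entirely, which is what lets a `PinaDataLoader(...)` call hand back a fully constructed `StackedDataLoader` untouched. A minimal sketch of the same pattern (the `_Loader`/`_Stacked` names are illustrative stand-ins, not from the codebase):

    class _Stacked:
        """Stand-in for StackedDataLoader: a completely independent class."""
        def __init__(self, data):
            self.data = data


    class _Loader:
        """Stand-in for PinaDataLoader, dispatching on a keyword argument."""
        def __new__(cls, *args, **kwargs):
            if kwargs.get("batching_mode") == "stacked":
                # Returning a non-instance of cls: _Loader.__init__ is never called.
                return _Stacked(args[0])
            return super().__new__(cls)

        def __init__(self, data, **kwargs):
            self.data = data


    assert isinstance(_Loader({}, batching_mode="stacked"), _Stacked)
    assert isinstance(_Loader({}), _Loader)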
@@ -193,9 +211,10 @@ class PinaDataLoader:
             # (the sum of the batch sizes is equal to
             # n_conditions * batch_size)
             batch_size_per_dataset = {
-                split: batch_size for split in dataset_dict.keys()
+                split: min(batch_size, len(ds))
+                for split, ds in dataset_dict.items()
             }
-        elif batching_mode == "propotional":
+        elif batching_mode == "proportional":
             # batch sizes is equal to the specified batch size)
             batch_size_per_dataset = self._compute_batch_size()
 
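The `_compute_batch_size` helper is not shown in this diff. Based on the surrounding comments (in proportional mode the per-split sizes should sum to the requested `batch_size`, in proportion to each dataset's length), a plausible sketch might look like the following; this is an assumption about its behavior, not the repository's actual implementation:

    def compute_proportional_batch_sizes(dataset_lengths, batch_size):
        # Hypothetical helper: split one global batch across datasets in
        # proportion to their sizes, keeping at least one sample per dataset.
        total = sum(dataset_lengths.values())
        sizes = {
            split: max(1, round(batch_size * n / total))
            for split, n in dataset_lengths.items()
        }
        # Rounding can leave the sum slightly off; absorb the difference
        # in the largest split so the sizes sum to batch_size.
        largest = max(sizes, key=sizes.get)
        sizes[largest] += batch_size - sum(sizes.values())
        return sizes


    print(compute_proportional_batch_sizes({"physics": 900, "data": 100}, 50))
    # {'physics': 45, 'data': 5}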
@@ -296,30 +315,33 @@ class PinaDataLoader:
     def __iter__(self):
         """
-        Iterate over the dataloader.
+        Iterate over the dataloader. Yields a dictionary mapping split name to batch.
 
-        :return: Yields batches from the dataloader.
-        :rtype: dict
+        The iteration logic for 'separate_conditions' is now iterative and memory-efficient.
         """
         if self.batching_mode == "separate_conditions":
+            tmp = []
             for split, dl in self.dataloaders.items():
-                for batch in dl:
-                    yield {split: batch}
+                len_split = len(dl)
+                for i, batch in enumerate(dl):
+                    tmp.append({split: batch})
+                    if i + 1 >= len_split:
+                        break
+            random.shuffle(tmp)
+            for batch_dict in tmp:
+                yield batch_dict
             return
 
+        # Common_batch_size or Proportional mode (round-robin sampling)
         iterators = {
             split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
         }
 
+        # Iterate for the length of the longest dataloader
        for _ in range(len(self)):
-            batch_dict = {}
+            batch_dict: BatchDict = {}
             for split, it in iterators.items():
-
-                # Iterate through each dataloader and get the next batch
-                batch = next(it, None)
-                # Check if batch is None (in case of uneven lengths)
-                if batch is None:
-                    return
-
-                batch_dict[split] = batch
+                # Since we use itertools.cycle, next(it) will always yield a batch
+                # by repeating the dataset, so no need for the 'if batch is None: return' check.
+                batch_dict[split] = next(it)
             yield batch_dict
 
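The round-robin branch relies on `itertools.cycle` to keep short dataloaders producing batches until the longest one is exhausted, which is why the old `None` check could be dropped. A toy illustration with plain lists standing in for dataloaders:

    import itertools

    loaders = {"a": [1, 2, 3, 4], "b": [10, 20]}  # uneven "dataloaders"
    iterators = {k: itertools.cycle(v) for k, v in loaders.items()}

    # One epoch = length of the longest loader; "b" wraps around.
    for _ in range(max(len(v) for v in loaders.values())):
        print({k: next(it) for k, it in iterators.items()})
    # {'a': 1, 'b': 10}
    # {'a': 2, 'b': 20}
    # {'a': 3, 'b': 10}
    # {'a': 4, 'b': 20}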
pina/data/stacked_dataloader.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import torch
+from math import ceil
+
+
+class StackedDataLoader:
+    def __init__(self, datasets, batch_size=32, shuffle=True):
+        for d in datasets.values():
+            if d.is_graph_dataset:
+                raise ValueError("Graph datasets are not supported in stacked mode")
+        self.chunks = {}
+        self.total_length = 0
+        self.indices = []
+
+        self._init_chunks(datasets)
+        self.indices = list(range(self.total_length))
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        if self.shuffle:
+            torch.random.manual_seed(42)
+            self.indices = torch.randperm(self.total_length).tolist()
+        self.datasets = datasets
+
+    def _init_chunks(self, datasets):
+        inc = 0
+        total_length = 0
+        for name, dataset in datasets.items():
+            self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
+            inc += len(dataset)
+        self.total_length = inc
+
+    def __len__(self):
+        return ceil(self.total_length / self.batch_size)
+
+    def _build_batch_indices(self, batch_idx):
+        start = batch_idx * self.batch_size
+        end = min(start + self.batch_size, self.total_length)
+        return self.indices[start:end]
+
+    def __iter__(self):
+        for batch_idx in range(len(self)):
+            batch_indices = self._build_batch_indices(batch_idx)
+            batch_data = {}
+            for name, chunk in self.chunks.items():
+                local_indices = [
+                    idx - chunk["start"]
+                    for idx in batch_indices
+                    if chunk["start"] <= idx < chunk["end"]
+                ]
+                if local_indices:
+                    batch_data[name] = self.datasets[name].getitem_from_list(
+                        local_indices
+                    )
+            yield batch_data
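For context, a rough usage sketch of the new class. `StackedDataLoader` assigns every dataset a contiguous slice of one global index range and slices batches out of that range, so samples from different conditions can share a batch. It only touches `is_graph_dataset` and `getitem_from_list` on the dataset objects, so the `_FakeDataset` stand-in below fakes just those two; it is not the real PinaDataset API:

    class _FakeDataset:
        # Minimal stand-in for a PinaDataset: only the two attributes
        # StackedDataLoader actually uses.
        is_graph_dataset = False

        def __init__(self, values):
            self.values = values

        def __len__(self):
            return len(self.values)

        def getitem_from_list(self, indices):
            return [self.values[i] for i in indices]


    datasets = {
        "physics": _FakeDataset(list(range(6))),
        "data": _FakeDataset(list(range(100, 104))),
    }
    loader = StackedDataLoader(datasets, batch_size=4, shuffle=False)
    for batch in loader:
        print(batch)
    # {'physics': [0, 1, 2, 3]}
    # {'physics': [4, 5], 'data': [100, 101]}
    # {'data': [102, 103]}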