add stack conditions option
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
"""DataLoader module for PinaDataset."""
|
"""DataLoader module for PinaDataset."""
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
import random
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import SequentialSampler
|
from torch.utils.data.sampler import SequentialSampler
|
||||||
|
from .stacked_dataloader import StackedDataLoader
|
||||||
|
|
||||||
|
|
||||||
class DummyDataloader:
|
class DummyDataloader:
|
||||||
@@ -152,6 +154,22 @@ class PinaDataLoader:
|
|||||||
Custom DataLoader for PinaDataset.
|
Custom DataLoader for PinaDataset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
|
||||||
|
batch_size = kwargs.get("batch_size")
|
||||||
|
if batching_mode == "stacked" and batch_size is not None:
|
||||||
|
return StackedDataLoader(
|
||||||
|
args[0],
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=kwargs.get("shuffle", True),
|
||||||
|
)
|
||||||
|
elif batch_size is None:
|
||||||
|
kwargs["batching_mode"] = "proportional"
|
||||||
|
print(
|
||||||
|
"Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
|
||||||
|
)
|
||||||
|
return super(PinaDataLoader, cls).__new__(cls)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_dict,
|
dataset_dict,
|
||||||
@@ -193,9 +211,10 @@ class PinaDataLoader:
|
|||||||
# (the sum of the batch sizes is equal to
|
# (the sum of the batch sizes is equal to
|
||||||
# n_conditions * batch_size)
|
# n_conditions * batch_size)
|
||||||
batch_size_per_dataset = {
|
batch_size_per_dataset = {
|
||||||
split: batch_size for split in dataset_dict.keys()
|
split: min(batch_size, len(ds))
|
||||||
|
for split, ds in dataset_dict.items()
|
||||||
}
|
}
|
||||||
elif batching_mode == "propotional":
|
elif batching_mode == "proportional":
|
||||||
# batch sizes is equal to the specified batch size)
|
# batch sizes is equal to the specified batch size)
|
||||||
batch_size_per_dataset = self._compute_batch_size()
|
batch_size_per_dataset = self._compute_batch_size()
|
||||||
|
|
||||||
@@ -296,30 +315,33 @@ class PinaDataLoader:
|
|||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""
|
"""
|
||||||
Iterate over the dataloader.
|
Iterate over the dataloader. Yields a dictionary mapping split name to batch.
|
||||||
|
|
||||||
:return: Yields batches from the dataloader.
|
The iteration logic for 'separate_conditions' is now iterative and memory-efficient.
|
||||||
:rtype: dict
|
|
||||||
"""
|
"""
|
||||||
if self.batching_mode == "separate_conditions":
|
if self.batching_mode == "separate_conditions":
|
||||||
|
tmp = []
|
||||||
for split, dl in self.dataloaders.items():
|
for split, dl in self.dataloaders.items():
|
||||||
for batch in dl:
|
len_split = len(dl)
|
||||||
yield {split: batch}
|
for i, batch in enumerate(dl):
|
||||||
|
tmp.append({split: batch})
|
||||||
|
if i + 1 >= len_split:
|
||||||
|
break
|
||||||
|
random.shuffle(tmp)
|
||||||
|
for batch_dict in tmp:
|
||||||
|
yield batch_dict
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Common_batch_size or Proportional mode (round-robin sampling)
|
||||||
iterators = {
|
iterators = {
|
||||||
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
|
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Iterate for the length of the longest dataloader
|
||||||
for _ in range(len(self)):
|
for _ in range(len(self)):
|
||||||
batch_dict = {}
|
batch_dict: BatchDict = {}
|
||||||
for split, it in iterators.items():
|
for split, it in iterators.items():
|
||||||
|
# Since we use itertools.cycle, next(it) will always yield a batch
|
||||||
# Iterate through each dataloader and get the next batch
|
# by repeating the dataset, so no need for the 'if batch is None: return' check.
|
||||||
batch = next(it, None)
|
batch_dict[split] = next(it)
|
||||||
# Check if batch is None (in case of uneven lengths)
|
|
||||||
if batch is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
batch_dict[split] = batch
|
|
||||||
yield batch_dict
|
yield batch_dict
|
||||||
|
|||||||
53
pina/data/stacked_dataloader.py
Normal file
53
pina/data/stacked_dataloader.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
|
||||||
|
class StackedDataLoader:
|
||||||
|
def __init__(self, datasets, batch_size=32, shuffle=True):
|
||||||
|
for d in datasets.values():
|
||||||
|
if d.is_graph_dataset:
|
||||||
|
raise ValueError("Each dataset must be a dictionary")
|
||||||
|
self.chunks = {}
|
||||||
|
self.total_length = 0
|
||||||
|
self.indices = []
|
||||||
|
|
||||||
|
self._init_chunks(datasets)
|
||||||
|
self.indices = list(range(self.total_length))
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.shuffle = shuffle
|
||||||
|
if self.shuffle:
|
||||||
|
torch.random.manual_seed(42)
|
||||||
|
self.indices = torch.randperm(self.total_length).tolist()
|
||||||
|
self.datasets = datasets
|
||||||
|
|
||||||
|
def _init_chunks(self, datasets):
|
||||||
|
inc = 0
|
||||||
|
total_length = 0
|
||||||
|
for name, dataset in datasets.items():
|
||||||
|
self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
|
||||||
|
inc += len(dataset)
|
||||||
|
self.total_length = inc
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return ceil(self.total_length / self.batch_size)
|
||||||
|
|
||||||
|
def _build_batch_indices(self, batch_idx):
|
||||||
|
start = batch_idx * self.batch_size
|
||||||
|
end = min(start + self.batch_size, self.total_length)
|
||||||
|
return self.indices[start:end]
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for batch_idx in range(len(self)):
|
||||||
|
batch_indices = self._build_batch_indices(batch_idx)
|
||||||
|
batch_data = {}
|
||||||
|
for name, chunk in self.chunks.items():
|
||||||
|
local_indices = [
|
||||||
|
idx - chunk["start"]
|
||||||
|
for idx in batch_indices
|
||||||
|
if chunk["start"] <= idx < chunk["end"]
|
||||||
|
]
|
||||||
|
if local_indices:
|
||||||
|
batch_data[name] = self.datasets[name].getitem_from_list(
|
||||||
|
local_indices
|
||||||
|
)
|
||||||
|
yield batch_data
|
||||||
Reference in New Issue
Block a user