fix tests and modules: merge common_batch_size and separate_conditions into a single batching_mode option
@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
         val_size=0.1,
         batch_size=None,
         shuffle=True,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=0,
         pin_memory=False,
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
         # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.common_batch_size = common_batch_size
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode
         self.automatic_batching = automatic_batching

         # If batch size is None, num_workers has no effect
@@ -280,8 +278,7 @@ class PinaDataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
-            common_batch_size=self.common_batch_size,
-            separate_conditions=self.separate_conditions,
+            batching_mode=self.batching_mode,
             device=self.trainer.strategy.root_device,
         )
         if self.batch_size is None:
@@ -330,7 +327,7 @@ class PinaDataModule(LightningDataModule):
         :rtype: list[tuple]
         """

-        return [(k, v) for k, v in batch.items()]
+        return list(batch.items())

     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
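Across the PinaDataModule hunks above, the boolean pair `common_batch_size` / `separate_conditions` collapses into a single `batching_mode` string that is stored on the module and forwarded to the dataloader. A minimal usage sketch of the new keyword; the import path and any constructor arguments not visible in this diff are assumptions:

```python
# Hypothetical usage sketch; the import path is an assumption.
from pina.data import PinaDataModule

dm = PinaDataModule(
    problem,                            # some PINA problem, assumed in scope
    batch_size=32,
    shuffle=True,
    batching_mode="common_batch_size",  # was: common_batch_size=True
)

# Proportional batching, previously selected with common_batch_size=False:
dm_prop = PinaDataModule(problem, batch_size=32, batching_mode="proportional")
```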
@@ -47,7 +47,7 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = [i for i in range(len(dataset))]
+            idx = list(range(len(dataset)))
         self.dataset = dataset.getitem_from_list(idx)
         self.device = device
         self.dataset = (
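The `if`/`else` above shards a dataset across distributed ranks by striding indices with `world_size`, while the single-process branch keeps every index. A self-contained sketch of the same pattern, independent of the PINA classes (the exact loop bounds in `DummyDataloader` are assumed):

```python
# Rank-strided sharding: each rank takes indices
# rank, rank + world_size, rank + 2 * world_size, ...
def shard_indices(n_samples: int, rank: int, world_size: int) -> list[int]:
    if world_size > 1:
        return list(range(rank, n_samples, world_size))
    return list(range(n_samples))

assert shard_indices(10, rank=1, world_size=4) == [1, 5, 9]
```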
@@ -158,15 +158,25 @@ class PinaDataLoader:
         batch_size,
         num_workers=0,
         shuffle=False,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         device=None,
     ):
+        """
+        Initialize the PinaDataLoader.
+        :param dict dataset_dict: A dictionary mapping dataset names to their
+            respective PinaDataset instances.
+        :param int batch_size: The batch size for the dataloader.
+        :param int num_workers: Number of worker processes for data loading.
+        :param bool shuffle: Whether to shuffle the data at every epoch.
+        :param str batching_mode: The batching mode to use. Options are
+            "common_batch_size", "separate_conditions", and "proportional".
+        :param device: The device to which the data should be moved.
+        """
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode.lower()
         self.device = device

         # Batch size None means we want to load the entire dataset in a single
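The constructor now normalizes the mode with `.lower()`, and the new docstring names the three accepted values. The modes differ only in how per-condition batch sizes are derived; `_compute_batch_size` is not shown in this diff, so the proportional split below is a sketch of the documented behaviour (per-condition sizes summing roughly to the requested `batch_size`), with invented rounding:

```python
# Sketch of per-condition batch sizes under the three modes.
# dataset_sizes maps condition name -> number of samples (illustrative).
def batch_sizes_per_mode(mode: str, batch_size: int, dataset_sizes: dict) -> dict:
    if mode in ("common_batch_size", "separate_conditions"):
        # Every condition uses the full batch size
        # (total per step: n_conditions * batch_size).
        return {name: batch_size for name in dataset_sizes}
    if mode == "proportional":
        # Sizes proportional to each condition's share of the data
        # (total per step: roughly batch_size; rounding is an assumption).
        total = sum(dataset_sizes.values())
        return {
            name: max(1, round(batch_size * n / total))
            for name, n in dataset_sizes.items()
        }
    raise ValueError(f"unknown batching mode: {mode!r}")

print(batch_sizes_per_mode("proportional", 30, {"pde": 800, "bc": 200}))
# -> {'pde': 24, 'bc': 6}
```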
@@ -177,13 +187,14 @@ class PinaDataLoader:
             }
         else:
             # Compute batch size per dataset
-            if common_batch_size:  # all datasets have the same batch size
+            if batching_mode in ["common_batch_size", "separate_conditions"]:
                 # (the sum of the batch sizes is equal to
                 # n_conditions * batch_size)
                 batch_size_per_dataset = {
                     split: batch_size for split in dataset_dict.keys()
                 }
-            else:  # batch size proportional to dataset size (the sum of the
-                # batch sizes is equal to the specified batch size)
+            elif batching_mode == "proportional":
+                # Batch size proportional to dataset size (the sum of the
+                # batch sizes is equal to the specified batch size).
                 batch_size_per_dataset = self._compute_batch_size()

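One case the `if`/`elif` chain above does not cover is an unrecognized mode, which would leave `batch_size_per_dataset` unbound and surface later as a `NameError`. The commit adds no such guard; a possible defensive check, not part of this diff and with an invented message, might look like:

```python
# Hypothetical guard, not part of this commit: fail fast on a bad mode
# instead of hitting NameError when batch_size_per_dataset is never set.
VALID_MODES = ("common_batch_size", "separate_conditions", "proportional")
if batching_mode not in VALID_MODES:
    raise ValueError(
        f"Invalid batching_mode {batching_mode!r}; expected one of {VALID_MODES}."
    )
```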
@@ -242,6 +253,12 @@ class PinaDataLoader:
     def _create_dataloader(self, dataset, batch_size):
         """
         Create the dataloader for the given dataset.
+
+        :param PinaDataset dataset: The dataset for which to create the
+            dataloader.
+        :param int batch_size: The batch size for the dataloader.
+        :return: The created dataloader.
+        :rtype: :class:`torch.utils.data.DataLoader`
         """
         # If batch size is None, use DummyDataloader
         if batch_size is None or batch_size >= len(dataset):
@@ -270,7 +287,7 @@ class PinaDataLoader:
         """
         # If separate conditions, return sum of lengths of all dataloaders
         # else, return max length among dataloaders
-        if self.separate_conditions:
+        if self.batching_mode == "separate_conditions":
             return sum(len(dl) for dl in self.dataloaders.values())
         return max(len(dl) for dl in self.dataloaders.values())

@@ -280,7 +297,7 @@ class PinaDataLoader:
         :return: Yields batches from the dataloader.
         :rtype: dict
         """
-        if self.separate_conditions:
+        if self.batching_mode == "separate_conditions":
             for split, dl in self.dataloaders.items():
                 for batch in dl:
                     yield {split: batch}
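`__len__` and `__iter__` now key off the same string: in "separate_conditions" mode every batch holds a single condition and the epoch length is the sum of the per-condition dataloader lengths, while the other modes step through conditions side by side (that combined branch is outside this diff). A toy illustration of the two epoch shapes, with invented condition names:

```python
# Toy illustration of the two iteration shapes (names are invented).
per_condition_batches = {"pde": ["p0", "p1", "p2"], "bc": ["b0"]}

# "separate_conditions": one condition per step, 3 + 1 = 4 steps,
# each step a dict with a single key (matching `yield {split: batch}`).
separate = [{name: b} for name, bs in per_condition_batches.items() for b in bs]
assert len(separate) == 4

# Other modes (inferred from the max() in __len__): conditions advance
# together, shorter ones presumably cycling, giving max(3, 1) = 3 steps.
combined_steps = max(len(bs) for bs in per_condition_batches.values())
assert combined_steps == 3
```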
@@ -31,8 +31,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -85,8 +84,7 @@ class Trainer(lightning.pytorch.Trainer):
             train_size=train_size,
             test_size=test_size,
             val_size=val_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             compile=compile,
         )
@@ -141,8 +139,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             pin_memory=pin_memory,
             num_workers=num_workers,
@@ -180,8 +177,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size,
         val_size,
         batch_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -233,8 +229,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            separate_conditions=seperate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -286,8 +281,7 @@ class Trainer(lightning.pytorch.Trainer):
         train_size,
         test_size,
         val_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         compile,
     ):
@@ -314,8 +308,7 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(common_batch_size, bool)
-        check_consistency(seperate_conditions, bool)
+        check_consistency(batching_mode, str)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
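At the Trainer level the change is a pass-through: the keyword is validated as a string by `check_consistency` and forwarded to the data module. A hedged end-to-end sketch; the import paths and the solver setup are assumptions modeled on the test code below:

```python
# Hypothetical end-to-end usage; import paths are assumptions.
import torch
from pina import Trainer
from pina.solver import SupervisedSolver
from pina.problem.zoo import SupervisedProblem

problem = SupervisedProblem(input_=torch.rand(100, 10), output_=torch.rand(100, 10))
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
    solver,
    train_size=0.7,
    val_size=0.3,
    test_size=0.0,
    batch_size=10,
    batching_mode="proportional",  # replaces the old common_batch_size=False
)
```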
@@ -159,7 +159,11 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
     [(input_tensor, output_tensor), (input_graph, output_graph)],
 )
 @pytest.mark.parametrize("automatic_batching", [True, False])
-def test_dataloader(input_, output_, automatic_batching):
+@pytest.mark.parametrize("batch_size", [None, 10])
+@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"])
+def test_dataloader(
+    input_, output_, automatic_batching, batch_size, batching_mode
+):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
     trainer = Trainer(
@@ -169,7 +173,7 @@ def test_dataloader(input_, output_, automatic_batching):
         val_size=0.3,
         test_size=0.0,
         automatic_batching=automatic_batching,
-        common_batch_size=True,
+        batching_mode=batching_mode,
     )
     dm = trainer.data_module
     dm.setup()
@@ -187,7 +191,7 @@ def test_dataloader(input_, output_, automatic_batching):

     dataloader = dm.val_dataloader()
     assert isinstance(dataloader, PinaDataLoader)
-    assert len(dataloader) == 3
+    assert len(dataloader) == (3 if batch_size is not None else 1)
     data = next(iter(dataloader))
     assert isinstance(data, dict)
     if isinstance(input_, list):
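The expected lengths follow from the split arithmetic: if the fixtures hold 100 samples (their definitions sit outside this diff, so this is an assumption), `val_size=0.3` leaves 30 validation samples, so `batch_size=10` yields 30 / 10 = 3 batches, while `batch_size=None` loads the whole split as one batch, hence length 1.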
@@ -225,7 +229,6 @@ def test_dataloader_labels(input_, output_, automatic_batching):
         val_size=0.3,
         test_size=0.0,
         automatic_batching=automatic_batching,
-        common_batch_size=True,
     )
     dm = trainer.data_module
     dm.setup()
@@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile):
         assert isinstance(solver.model, OptimizedModule)


+if __name__ == "__main__":
+    test_solver_train(use_lt=True, batch_size=20, compile=True)
+
+
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("use_lt", [True, False])
 def test_solver_train_graph(batch_size, use_lt):
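The `__main__` block added here lets the test file double as a quick smoke-test script, running one parametrization of `test_solver_train` directly without pytest; since it sits mid-file, only the functions defined above it exist when it runs.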