From 1b2154d8fb9c4db57b01afda3cb0fddfe8fcf5d4 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Mon, 17 Mar 2025 13:47:24 +0100
Subject: [PATCH] Add docstring for repeat in DataModule
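
The repeat flag controls what happens when batch_size exceeds the
number of elements in a condition: with repeat=True the elements of
that condition are recycled until the batch is full, while with
repeat=False the batch is capped at the condition size. A minimal
usage sketch (values are illustrative and `solver` stands for any
solver built the usual way in PINA):

    from pina import Trainer

    # Suppose the smallest condition holds 10 points. With
    # batch_size=16 and repeat=True, those 10 points are repeated
    # until the batch reaches 16 elements; with repeat=False the
    # batch for that condition stops at 10.
    trainer = Trainer(
        solver,
        batch_size=16,
        train_size=0.8,
        test_size=0.1,
        val_size=0.1,
        repeat=True,
        automatic_batching=False,
    )
    trainer.train()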
---
 pina/data/data_module.py | 18 +++++++++++---
 pina/trainer.py          | 54 +++++++++++++++++++++++++++-------------
 2 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 6f3c751..1401613 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -283,10 +283,20 @@ class PinaDataModule(LightningDataModule):
             Default is ``None``.
         :param bool shuffle: Whether to shuffle the dataset before splitting.
             Default ``True``.
-        :param bool repeat: Whether to repeat the dataset indefinitely.
-            Default ``False``.
-        :param automatic_batching: Whether to enable automatic batching.
-            Default ``False``.
+        :param bool repeat: If ``True``, when the batch size is larger than
+            the number of elements in a specific condition, the elements are
+            repeated until the batch size is reached. If ``False``, the number
+            of elements in the batch is the minimum between the batch size and
+            the number of elements in the condition. Default is ``False``.
+        :param bool automatic_batching: If ``True``, automatic PyTorch batching
+            is performed, which consists of extracting one element at a time
+            from the dataset and collating them into a batch. This is useful
+            when the dataset is too large to fit into memory. On the other
+            hand, if ``False``, the items are retrieved from the dataset all
+            at once, avoiding the overhead of collating them into a batch and
+            reducing the number of ``__getitem__`` calls. This is useful when
+            the dataset fits into memory. Avoid using automatic batching when
+            ``batch_size`` is large. Default is ``False``.
         :param int num_workers: Number of worker threads for data loading.
             Default ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
diff --git a/pina/trainer.py b/pina/trainer.py
index a29152c..e5515d0 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -26,6 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
+        repeat=False,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -49,9 +50,13 @@
             validation dataset. Default is ``0.0``.
         :param bool compile: If ``True``, the model is compiled before training.
             Default is ``False``. For Windows users, it is always disabled.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training. For further details, see the
+            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch batching
-            is performed. Avoid using automatic batching when ``batch_size`` is
-            large. Default is ``False``.
+            is performed, otherwise the items are retrieved from the dataset
+            all at once. For further details, see the
+            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
         :param int num_workers: The number of worker threads for data loading.
             Default is ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
@@ -65,12 +70,13 @@
         """
         # check consistency for init types
         self._check_input_consistency(
-            solver,
-            train_size,
-            test_size,
-            val_size,
-            automatic_batching,
-            compile,
+            solver=solver,
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            compile=compile,
         )
         pin_memory, num_workers, shuffle, batch_size = (
             self._check_consistency_and_set_defaults(
@@ -110,14 +116,15 @@
         self._move_to_device()
         self.data_module = None
         self._create_datamodule(
-            train_size,
-            test_size,
-            val_size,
-            batch_size,
-            automatic_batching,
-            pin_memory,
-            num_workers,
-            shuffle,
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            batch_size=batch_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            shuffle=shuffle,
         )

         # logging
@@ -151,6 +158,7 @@
         test_size,
         val_size,
         batch_size,
+        repeat,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -169,6 +177,8 @@
         :param float val_size: The percentage of elements to include in the
             validation dataset.
         :param int batch_size: The number of samples per batch to load.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch. If ``True``, automatic PyTorch batching is
             performed, which consists of extracting one element at a time
@@ -206,6 +216,7 @@
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
+            repeat=repeat,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -253,7 +264,13 @@
     @staticmethod
     def _check_input_consistency(
-        solver, train_size, test_size, val_size, automatic_batching, compile
+        solver,
+        train_size,
+        test_size,
+        val_size,
+        repeat,
+        automatic_batching,
+        compile,
     ):
         """
         Verifies the consistency of the parameters for the solver
         configuration.
@@ -265,6 +282,8 @@
             test dataset.
         :param float val_size: The percentage of elements to include in the
             validation dataset.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch.
         :param bool compile: If ``True``, the model is compiled before training.
@@ -274,6 +293,7 @@
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
+        check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None: