From 070b5137ba9dc495df4a02373e70ffa9a641f6b0 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Mon, 17 Mar 2025 22:23:34 +0100
Subject: [PATCH] Minor fix

---
 pina/data/data_module.py | 19 ++++++-------------
 pina/trainer.py          | 25 +++++++++++--------------
 2 files changed, 17 insertions(+), 27 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 1401613..349d74d 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -81,16 +81,9 @@ class Collator:
         :param dict max_conditions_lengths: ``dict`` containing the maximum
             number of data points to consider in a single batch for each
             condition.
-        :param bool automatic_batching: Whether to enable automatic batching.
-            If ``True``, automatic PyTorch batching
-            is performed, which consists of extracting one element at a time
-            from the dataset and collating them into a batch. This is useful
-            when the dataset is too large to fit into memory. On the other hand,
-            if ``False``, the items are retrieved from the dataset all at once
-            avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+        :param bool automatic_batching: Whether automatic PyTorch batching is
+            enabled or not. For more information, see the
+            :class:`~pina.data.data_module.PinaDataModule` class.
         :param PinaDataset dataset: The dataset where the data is stored.
         """
 
@@ -294,9 +287,9 @@ class PinaDataModule(LightningDataModule):
             when the dataset is too large to fit into memory. On the other hand,
             if ``False``, the items are retrieved from the dataset all at once
             avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+            ``__getitem__`` calls to the dataset. This is useful when the
+            dataset fits into memory. Avoid using automatic batching when
+            ``batch_size`` is large. Default is ``False``.
         :param int num_workers: Number of worker threads for data loading.
             Default ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
diff --git a/pina/trainer.py b/pina/trainer.py
index e5515d0..78dd77a 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -26,7 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=False,
+        repeat=None,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -52,11 +52,13 @@ class Trainer(lightning.pytorch.Trainer):
             Default is ``False``. For Windows users, it is always disabled.
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch
             batching is performed, otherwise the items are retrieved from
             the dataset all at once. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param int num_workers: The number of worker threads for data loading.
             Default is ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
@@ -105,7 +107,9 @@ class Trainer(lightning.pytorch.Trainer):
         if compile is None or sys.platform == "win32":
             compile = False
 
-        self.automatic_batching = (
+        repeat = repeat if repeat is not None else False
+
+        automatic_batching = (
             automatic_batching if automatic_batching is not None else False
         )
 
@@ -180,15 +184,7 @@ class Trainer(lightning.pytorch.Trainer):
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
-            with PyTorch. If ``True``, automatic PyTorch batching
-            is performed, which consists of extracting one element at a time
-            from the dataset and collating them into a batch. This is useful
-            when the dataset is too large to fit into memory. On the other hand,
-            if ``False``, the items are retrieved from the dataset all at once
-            avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+            with PyTorch.
         :param bool pin_memory: Whether to use pinned memory for faster data
             transfer to GPU.
         :param int num_workers: The number of worker threads for data loading.
@@ -293,7 +289,8 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(repeat, bool)
+        if repeat is not None:
+            check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
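Note for readers (not part of the patch): below is a minimal, self-contained sketch of the default-handling pattern the trainer change applies. The class and helper here are hypothetical stand-ins, not PINA's actual Trainer or check_consistency; the point is only that `repeat` and `automatic_batching` now default to None so the constructor can tell "not provided" apart from an explicit False, validate only user-supplied values, and then fall back to the documented False defaults.

# Sketch of the None-sentinel default pattern used in this patch.
# TrainerSketch and this check_consistency are illustrative stand-ins.

def check_consistency(value, expected_type):
    # Stand-in for PINA's consistency check: fail on wrong types.
    if not isinstance(value, expected_type):
        raise ValueError(
            f"Expected {expected_type.__name__}, got {type(value).__name__}"
        )

class TrainerSketch:
    def __init__(self, repeat=None, automatic_batching=None):
        # Validate only the values the caller actually passed.
        if repeat is not None:
            check_consistency(repeat, bool)
        if automatic_batching is not None:
            check_consistency(automatic_batching, bool)
        # Fall back to the documented defaults (both False).
        self.repeat = repeat if repeat is not None else False
        self.automatic_batching = (
            automatic_batching if automatic_batching is not None else False
        )

t1 = TrainerSketch()               # both flags fall back to False
t2 = TrainerSketch(repeat=True)    # explicit value is validated and kept
assert t1.repeat is False and t2.repeat is True

Using None as the sentinel keeps the documented defaults (``False``) unchanged for users while letting the constructor skip type checks on arguments that were never passed.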