Add docstring for repeat in DataModule

Author: FilippoOlivo
Date: 2025-03-17 13:47:24 +01:00
Committed by: Nicola Demo
Parent: 7f89c4f852
Commit: 1b2154d8fb
2 changed files with 51 additions and 21 deletions


@@ -283,10 +283,20 @@ class PinaDataModule(LightningDataModule):
Default is ``None``.
:param bool shuffle: Whether to shuffle the dataset before splitting.
Default ``True``.
:param bool repeat: Whether to repeat the dataset indefinitely.
Default ``False``.
:param automatic_batching: Whether to enable automatic batching.
Default ``False``.
:param bool repeat: If ``True``, when the batch size is larger than the
number of elements in a specific condition, the elements are
repeated until the batch size is reached. If ``False``, the number
of elements in the batch is the minimum between the batch size and
the number of elements in the condition. Default is ``False``.
:param automatic_batching: If ``True``, automatic PyTorch batching
is performed, which consists of extracting one element at a time
from the dataset and collating them into a batch. This is useful
when the dataset is too large to fit into memory. On the other hand,
if ``False``, the items are retrieved from the dataset all at once,
avoiding the overhead of collating them into a batch and reducing the
``__getitem__`` calls to the dataset. This is useful when the dataset
fits into memory. Avoid using automatic batching when ``batch_size``
is large. Default is ``False``.
:param int num_workers: Number of worker threads for data loading.
Default ``0`` (serial loading).
:param bool pin_memory: Whether to use pinned memory for faster data

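To make the two behaviours concrete, here is a small illustrative snippet (plain Python, not part of this commit) that mimics the documented semantics of ``repeat`` for a single condition holding fewer elements than ``batch_size``:

# Illustrative only: reproduces the documented behaviour of ``repeat``
# for one condition; PINA's actual dataset classes are not used here.
def build_batch(condition_elements, batch_size, repeat=False):
    n = len(condition_elements)
    if repeat and n < batch_size:
        # Cycle through the condition's elements until the batch is full.
        return [condition_elements[i % n] for i in range(batch_size)]
    # Otherwise the batch contains min(batch_size, n) elements.
    return condition_elements[:batch_size]

print(build_batch(["a", "b", "c"], batch_size=5, repeat=True))   # ['a', 'b', 'c', 'a', 'b']
print(build_batch(["a", "b", "c"], batch_size=5, repeat=False))  # ['a', 'b', 'c']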

@@ -26,6 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=0.0,
val_size=0.0,
compile=None,
repeat=False,
automatic_batching=None,
num_workers=None,
pin_memory=None,
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled.
:param bool repeat: Whether to repeat the data in each condition
during training. For further details, see the
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed. Avoid using automatic batching when ``batch_size`` is
large. Default is ``False``.
is performed, otherwise the items are retrieved from the dataset
all at once. For further details, see the
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
:param int num_workers: The number of worker threads for data loading.
Default is ``0`` (serial loading).
:param bool pin_memory: Whether to use pinned memory for faster data
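A hedged usage sketch for the keyword arguments documented above; the ``solver`` object and the ``pina`` import path are assumptions and are not part of this diff, only ``repeat`` and ``automatic_batching`` are touched by this commit:

# Hypothetical usage sketch, not taken from the PINA test suite.
from pina import Trainer  # import path assumed

def make_trainer(solver):
    # ``solver`` is an already-configured PINA solver (not shown here).
    return Trainer(
        solver,
        train_size=0.8,
        test_size=0.1,
        val_size=0.1,
        batch_size=128,
        repeat=True,               # pad small conditions up to batch_size
        automatic_batching=False,  # fetch items in bulk, no per-item collation
    )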
@@ -65,12 +70,13 @@ class Trainer(lightning.pytorch.Trainer):
"""
# check consistency for init types
self._check_input_consistency(
solver,
train_size,
test_size,
val_size,
automatic_batching,
compile,
solver=solver,
train_size=train_size,
test_size=test_size,
val_size=val_size,
repeat=repeat,
automatic_batching=automatic_batching,
compile=compile,
)
pin_memory, num_workers, shuffle, batch_size = (
self._check_consistency_and_set_defaults(
@@ -110,14 +116,15 @@ class Trainer(lightning.pytorch.Trainer):
self._move_to_device()
self.data_module = None
self._create_datamodule(
train_size,
test_size,
val_size,
batch_size,
automatic_batching,
pin_memory,
num_workers,
shuffle,
train_size=train_size,
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
shuffle=shuffle,
)
# logging
@@ -151,6 +158,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size,
val_size,
batch_size,
repeat,
automatic_batching,
pin_memory,
num_workers,
@@ -169,6 +177,8 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the data in each condition
during training.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch. If ``True``, automatic PyTorch batching
is performed, which consists of extracting one element at a time
@@ -206,6 +216,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
@@ -253,7 +264,13 @@ class Trainer(lightning.pytorch.Trainer):
@staticmethod
def _check_input_consistency(
solver, train_size, test_size, val_size, automatic_batching, compile
solver,
train_size,
test_size,
val_size,
repeat,
automatic_batching,
compile,
):
"""
Verifies the consistency of the parameters for the solver configuration.
@@ -265,6 +282,8 @@ class Trainer(lightning.pytorch.Trainer):
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool repeat: Whether to repeat the data in each condition
during training.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
@@ -274,6 +293,7 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
check_consistency(repeat, bool)
if automatic_batching is not None:
check_consistency(automatic_batching, bool)
if compile is not None:
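For context, the added ``check_consistency(repeat, bool)`` call means a non-boolean ``repeat`` is rejected when the Trainer is constructed. A rough stand-in for the helper (assumed behaviour for illustration; PINA's actual implementation in its utilities module may differ):

# Assumed behaviour of the consistency check, for illustration only.
def check_consistency(obj, expected_type):
    if not isinstance(obj, expected_type):
        raise ValueError(
            f"Expected {expected_type.__name__}, got {type(obj).__name__}."
        )

check_consistency(True, bool)   # passes silently
check_consistency("yes", bool)  # raises ValueError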