Add docstring for repeat in DataModule
This commit is contained in:
@@ -283,10 +283,20 @@ class PinaDataModule(LightningDataModule):
|
||||
Default is ``None``.
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
Default ``True``.
|
||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||
Default ``False``.
|
||||
:param automatic_batching: Whether to enable automatic batching.
|
||||
Default ``False``.
|
||||
:param bool repeat: If ``True``, in case of batch size larger than the
|
||||
number of elements in a specific condition, the elements are
|
||||
repeated until the batch size is reached. If ``False``, the number
|
||||
of elements in the batch is the minimum between the batch size and
|
||||
the number of elements in the condition. Default is ``False``.
|
||||
:param automatic_batching: If ``True``, automatic PyTorch batching
|
||||
is performed, which consists of extracting one element at a time
|
||||
from the dataset and collating them into a batch. This is useful
|
||||
when the dataset is too large to fit into memory. On the other hand,
|
||||
if ``False``, the items are retrieved from the dataset all at once
|
||||
avoind the overhead of collating them into a batch and reducing the
|
||||
__getitem__ calls to the dataset. This is useful when the dataset
|
||||
fits into memory. Avoid using automatic batching when ``batch_size``
|
||||
is large. Default is ``False``.
|
||||
:param int num_workers: Number of worker threads for data loading.
|
||||
Default ``0`` (serial loading).
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
|
||||
@@ -26,6 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size=0.0,
|
||||
val_size=0.0,
|
||||
compile=None,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=None,
|
||||
pin_memory=None,
|
||||
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
validation dataset. Default is ``0.0``.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
Default is ``False``. For Windows users, it is always disabled.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training. For further details, see the
|
||||
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
|
||||
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
||||
is performed. Avoid using automatic batching when ``batch_size`` is
|
||||
large. Default is ``False``.
|
||||
is performed, otherwise the items are retrieved from the dataset
|
||||
all at once. For further details, see the
|
||||
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
|
||||
:param int num_workers: The number of worker threads for data loading.
|
||||
Default is ``0`` (serial loading).
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
@@ -65,12 +70,13 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
# check consistency for init types
|
||||
self._check_input_consistency(
|
||||
solver,
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
automatic_batching,
|
||||
compile,
|
||||
solver=solver,
|
||||
train_size=train_size,
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
compile=compile,
|
||||
)
|
||||
pin_memory, num_workers, shuffle, batch_size = (
|
||||
self._check_consistency_and_set_defaults(
|
||||
@@ -110,14 +116,15 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
self._move_to_device()
|
||||
self.data_module = None
|
||||
self._create_datamodule(
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
batch_size,
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
num_workers,
|
||||
shuffle,
|
||||
train_size=train_size,
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
batch_size=batch_size,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
# logging
|
||||
@@ -151,6 +158,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size,
|
||||
val_size,
|
||||
batch_size,
|
||||
repeat,
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
num_workers,
|
||||
@@ -169,6 +177,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch. If ``True``, automatic PyTorch batching
|
||||
is performed, which consists of extracting one element at a time
|
||||
@@ -206,6 +216,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
batch_size=batch_size,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
@@ -253,7 +264,13 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
|
||||
@staticmethod
|
||||
def _check_input_consistency(
|
||||
solver, train_size, test_size, val_size, automatic_batching, compile
|
||||
solver,
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
repeat,
|
||||
automatic_batching,
|
||||
compile,
|
||||
):
|
||||
"""
|
||||
Verifies the consistency of the parameters for the solver configuration.
|
||||
@@ -265,6 +282,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test dataset.
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
@@ -274,6 +293,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
check_consistency(train_size, float)
|
||||
check_consistency(test_size, float)
|
||||
check_consistency(val_size, float)
|
||||
check_consistency(repeat, bool)
|
||||
if automatic_batching is not None:
|
||||
check_consistency(automatic_batching, bool)
|
||||
if compile is not None:
|
||||
|
||||
Reference in New Issue
Block a user