Add docstring for repeat in DataModule
This commit is contained in:
committed by
Nicola Demo
parent
7f89c4f852
commit
1b2154d8fb
@@ -283,10 +283,20 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Default is ``None``.
|
Default is ``None``.
|
||||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||||
Default ``True``.
|
Default ``True``.
|
||||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
:param bool repeat: If ``True``, in case of batch size larger than the
|
||||||
Default ``False``.
|
number of elements in a specific condition, the elements are
|
||||||
:param automatic_batching: Whether to enable automatic batching.
|
repeated until the batch size is reached. If ``False``, the number
|
||||||
Default ``False``.
|
of elements in the batch is the minimum between the batch size and
|
||||||
|
the number of elements in the condition. Default is ``False``.
|
||||||
|
:param automatic_batching: If ``True``, automatic PyTorch batching
|
||||||
|
is performed, which consists of extracting one element at a time
|
||||||
|
from the dataset and collating them into a batch. This is useful
|
||||||
|
when the dataset is too large to fit into memory. On the other hand,
|
||||||
|
if ``False``, the items are retrieved from the dataset all at once
|
||||||
|
avoind the overhead of collating them into a batch and reducing the
|
||||||
|
__getitem__ calls to the dataset. This is useful when the dataset
|
||||||
|
fits into memory. Avoid using automatic batching when ``batch_size``
|
||||||
|
is large. Default is ``False``.
|
||||||
:param int num_workers: Number of worker threads for data loading.
|
:param int num_workers: Number of worker threads for data loading.
|
||||||
Default ``0`` (serial loading).
|
Default ``0`` (serial loading).
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size=0.0,
|
test_size=0.0,
|
||||||
val_size=0.0,
|
val_size=0.0,
|
||||||
compile=None,
|
compile=None,
|
||||||
|
repeat=False,
|
||||||
automatic_batching=None,
|
automatic_batching=None,
|
||||||
num_workers=None,
|
num_workers=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
validation dataset. Default is ``0.0``.
|
validation dataset. Default is ``0.0``.
|
||||||
:param bool compile: If ``True``, the model is compiled before training.
|
:param bool compile: If ``True``, the model is compiled before training.
|
||||||
Default is ``False``. For Windows users, it is always disabled.
|
Default is ``False``. For Windows users, it is always disabled.
|
||||||
|
:param bool repeat: Whether to repeat the dataset data in each
|
||||||
|
condition during training. For further details, see the
|
||||||
|
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
|
||||||
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
||||||
is performed. Avoid using automatic batching when ``batch_size`` is
|
is performed, otherwise the items are retrieved from the dataset
|
||||||
large. Default is ``False``.
|
all at once. For further details, see the
|
||||||
|
:class:`~pina.data.PinaDataModule` class. Default is ``False``.
|
||||||
:param int num_workers: The number of worker threads for data loading.
|
:param int num_workers: The number of worker threads for data loading.
|
||||||
Default is ``0`` (serial loading).
|
Default is ``0`` (serial loading).
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
@@ -65,12 +70,13 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
"""
|
"""
|
||||||
# check consistency for init types
|
# check consistency for init types
|
||||||
self._check_input_consistency(
|
self._check_input_consistency(
|
||||||
solver,
|
solver=solver,
|
||||||
train_size,
|
train_size=train_size,
|
||||||
test_size,
|
test_size=test_size,
|
||||||
val_size,
|
val_size=val_size,
|
||||||
automatic_batching,
|
repeat=repeat,
|
||||||
compile,
|
automatic_batching=automatic_batching,
|
||||||
|
compile=compile,
|
||||||
)
|
)
|
||||||
pin_memory, num_workers, shuffle, batch_size = (
|
pin_memory, num_workers, shuffle, batch_size = (
|
||||||
self._check_consistency_and_set_defaults(
|
self._check_consistency_and_set_defaults(
|
||||||
@@ -110,14 +116,15 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
self._move_to_device()
|
self._move_to_device()
|
||||||
self.data_module = None
|
self.data_module = None
|
||||||
self._create_datamodule(
|
self._create_datamodule(
|
||||||
train_size,
|
train_size=train_size,
|
||||||
test_size,
|
test_size=test_size,
|
||||||
val_size,
|
val_size=val_size,
|
||||||
batch_size,
|
batch_size=batch_size,
|
||||||
automatic_batching,
|
repeat=repeat,
|
||||||
pin_memory,
|
automatic_batching=automatic_batching,
|
||||||
num_workers,
|
pin_memory=pin_memory,
|
||||||
shuffle,
|
num_workers=num_workers,
|
||||||
|
shuffle=shuffle,
|
||||||
)
|
)
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
@@ -151,6 +158,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size,
|
test_size,
|
||||||
val_size,
|
val_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
repeat,
|
||||||
automatic_batching,
|
automatic_batching,
|
||||||
pin_memory,
|
pin_memory,
|
||||||
num_workers,
|
num_workers,
|
||||||
@@ -169,6 +177,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
:param float val_size: The percentage of elements to include in the
|
:param float val_size: The percentage of elements to include in the
|
||||||
validation dataset.
|
validation dataset.
|
||||||
:param int batch_size: The number of samples per batch to load.
|
:param int batch_size: The number of samples per batch to load.
|
||||||
|
:param bool repeat: Whether to repeat the dataset data in each
|
||||||
|
condition during training.
|
||||||
:param bool automatic_batching: Whether to perform automatic batching
|
:param bool automatic_batching: Whether to perform automatic batching
|
||||||
with PyTorch. If ``True``, automatic PyTorch batching
|
with PyTorch. If ``True``, automatic PyTorch batching
|
||||||
is performed, which consists of extracting one element at a time
|
is performed, which consists of extracting one element at a time
|
||||||
@@ -206,6 +216,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size=test_size,
|
test_size=test_size,
|
||||||
val_size=val_size,
|
val_size=val_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
repeat=repeat,
|
||||||
automatic_batching=automatic_batching,
|
automatic_batching=automatic_batching,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
@@ -253,7 +264,13 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_input_consistency(
|
def _check_input_consistency(
|
||||||
solver, train_size, test_size, val_size, automatic_batching, compile
|
solver,
|
||||||
|
train_size,
|
||||||
|
test_size,
|
||||||
|
val_size,
|
||||||
|
repeat,
|
||||||
|
automatic_batching,
|
||||||
|
compile,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Verifies the consistency of the parameters for the solver configuration.
|
Verifies the consistency of the parameters for the solver configuration.
|
||||||
@@ -265,6 +282,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test dataset.
|
test dataset.
|
||||||
:param float val_size: The percentage of elements to include in the
|
:param float val_size: The percentage of elements to include in the
|
||||||
validation dataset.
|
validation dataset.
|
||||||
|
:param bool repeat: Whether to repeat the dataset data in each
|
||||||
|
condition during training.
|
||||||
:param bool automatic_batching: Whether to perform automatic batching
|
:param bool automatic_batching: Whether to perform automatic batching
|
||||||
with PyTorch.
|
with PyTorch.
|
||||||
:param bool compile: If ``True``, the model is compiled before training.
|
:param bool compile: If ``True``, the model is compiled before training.
|
||||||
@@ -274,6 +293,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
check_consistency(train_size, float)
|
check_consistency(train_size, float)
|
||||||
check_consistency(test_size, float)
|
check_consistency(test_size, float)
|
||||||
check_consistency(val_size, float)
|
check_consistency(val_size, float)
|
||||||
|
check_consistency(repeat, bool)
|
||||||
if automatic_batching is not None:
|
if automatic_batching is not None:
|
||||||
check_consistency(automatic_batching, bool)
|
check_consistency(automatic_batching, bool)
|
||||||
if compile is not None:
|
if compile is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user