Minor fix

FilippoOlivo
2025-03-17 22:23:34 +01:00
parent b92f39aead
commit e90be726da
2 changed files with 17 additions and 27 deletions

@@ -26,7 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=False,
+        repeat=None,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
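
Switching the ``repeat`` default from ``False`` to ``None`` aligns it with the sentinel style already used by ``compile``, ``automatic_batching``, ``num_workers``, and ``pin_memory``: ``None`` means "not supplied by the caller", which an explicit ``False`` cannot express. A minimal standalone sketch of the distinction (illustrative only, not the actual Trainer signature):

    def f(repeat=None):
        caller_set_it = repeat is not None
        # None means "unset"; fall back to the real default.
        repeat = repeat if repeat is not None else False
        return repeat, caller_set_it

    assert f() == (False, False)      # unset: default applies
    assert f(False) == (False, True)  # explicit False is distinguishable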
@@ -52,11 +52,13 @@ class Trainer(lightning.pytorch.Trainer):
             Default is ``False``. For Windows users, it is always disabled.
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch batching
             is performed, otherwise the items are retrieved from the dataset
             all at once. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param int num_workers: The number of worker threads for data loading.
             Default is ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
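
For context, a hypothetical instantiation exercising the documented defaults; ``solver`` is a stand-in for a PINA solver object and does not appear in this diff:

    trainer = Trainer(
        solver,                   # hypothetical solver instance
        repeat=None,              # resolved to False internally (next hunk)
        automatic_batching=None,  # resolved to False internally
        num_workers=0,            # documented default: serial loading
        pin_memory=False,
    )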
@@ -105,7 +107,9 @@ class Trainer(lightning.pytorch.Trainer):
         if compile is None or sys.platform == "win32":
             compile = False
-        self.automatic_batching = (
+        repeat = repeat if repeat is not None else False
+        automatic_batching = (
             automatic_batching if automatic_batching is not None else False
         )
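
With this hunk, ``repeat`` and ``automatic_batching`` follow the same resolve-once idiom: the ``None`` sentinel is replaced by the concrete default immediately, so all later code sees a plain ``bool``. A generic sketch of the idiom (``_resolve`` is an illustrative name, not a PINA helper):

    def _resolve(value, fallback):
        # Treat None as "unset" and substitute the concrete default.
        return value if value is not None else fallback

    repeat = _resolve(None, False)              # -> False
    automatic_batching = _resolve(True, False)  # -> True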
@@ -180,15 +184,7 @@ class Trainer(lightning.pytorch.Trainer):
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
-            with PyTorch. If ``True``, automatic PyTorch batching
-            is performed, which consists of extracting one element at a time
-            from the dataset and collating them into a batch. This is useful
-            when the dataset is too large to fit into memory. On the other hand,
-            if ``False``, the items are retrieved from the dataset all at once
-            avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+            with PyTorch.
         :param bool pin_memory: Whether to use pinned memory for faster data
             transfer to GPU.
         :param int num_workers: The number of worker threads for data loading.
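
The trimmed docstring described the tradeoff between automatic batching (one ``__getitem__`` call per element, then collation) and fetching items all at once. A small plain-PyTorch illustration of that distinction, independent of PINA:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.randn(8, 3))

    # Automatic batching: the DataLoader calls data[i] once per index and
    # collates the results into batches of 4.
    loader = DataLoader(data, batch_size=4)
    first_batch = next(iter(loader))

    # All-at-once retrieval: a single slice of the underlying tensor, with
    # no per-element __getitem__ calls and no collation step.
    whole = data.tensors[0][:4]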
@@ -293,7 +289,8 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(repeat, bool)
+        if repeat is not None:
+            check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
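
The new guard mirrors the checks for the other optional keywords: the type check runs only once the ``None`` sentinel has been ruled out. A self-contained sketch of the pattern, with a stand-in for ``check_consistency`` (the real PINA utility may behave differently):

    def check_consistency(value, expected):
        # Stand-in: raise if value is not an instance of the expected type.
        if not isinstance(value, expected):
            raise ValueError(
                f"expected {expected.__name__}, got {type(value).__name__}"
            )

    repeat = None
    if repeat is not None:  # None is allowed; only concrete values are checked
        check_consistency(repeat, bool)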