Minor fix

FilippoOlivo
2025-03-17 22:23:34 +01:00
parent b92f39aead
commit e90be726da
2 changed files with 17 additions and 27 deletions


@@ -81,16 +81,9 @@ class Collator:
     :param dict max_conditions_lengths: ``dict`` containing the maximum
         number of data points to consider in a single batch for
         each condition.
-    :param bool automatic_batching: Whether to enable automatic batching.
-        If ``True``, automatic PyTorch batching
-        is performed, which consists of extracting one element at a time
-        from the dataset and collating them into a batch. This is useful
-        when the dataset is too large to fit into memory. On the other hand,
-        if ``False``, the items are retrieved from the dataset all at once
-        avoind the overhead of collating them into a batch and reducing the
-        __getitem__ calls to the dataset. This is useful when the dataset
-        fits into memory. Avoid using automatic batching when ``batch_size``
-        is large. Default is ``False``.
+    :param bool automatic_batching: Whether automatic PyTorch batching is
+        enabled or not. For more information, see the
+        :class:`~pina.data.data_module.PinaDataModule` class.
     :param PinaDataset dataset: The dataset where the data is stored.
     """
@@ -294,9 +287,9 @@ class PinaDataModule(LightningDataModule):
         when the dataset is too large to fit into memory. On the other hand,
         if ``False``, the items are retrieved from the dataset all at once
         avoind the overhead of collating them into a batch and reducing the
-        __getitem__ calls to the dataset. This is useful when the dataset
-        fits into memory. Avoid using automatic batching when ``batch_size``
-        is large. Default is ``False``.
+        ``__getitem__`` calls to the dataset. This is useful when the
+        dataset fits into memory. Avoid using automatic batching when
+        ``batch_size`` is large. Default is ``False``.
     :param int num_workers: Number of worker threads for data loading.
         Default ``0`` (serial loading).
     :param bool pin_memory: Whether to use pinned memory for faster data


@@ -26,7 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=False,
+        repeat=None,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -52,11 +52,13 @@ class Trainer(lightning.pytorch.Trainer):
         Default is ``False``. For Windows users, it is always disabled.
     :param bool repeat: Whether to repeat the dataset data in each
         condition during training. For further details, see the
-        :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+        :class:`~pina.data.data_module.PinaDataModule` class. Default is
+        ``False``.
     :param bool automatic_batching: If ``True``, automatic PyTorch batching
         is performed, otherwise the items are retrieved from the dataset
         all at once. For further details, see the
-        :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+        :class:`~pina.data.data_module.PinaDataModule` class. Default is
+        ``False``.
     :param int num_workers: The number of worker threads for data loading.
         Default is ``0`` (serial loading).
     :param bool pin_memory: Whether to use pinned memory for faster data
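
As a side note, ``num_workers`` and ``pin_memory`` correspond to the standard torch.utils.data.DataLoader options of the same name. A small runnable illustration with plain PyTorch (the toy dataset is a stand-in, not PINA code):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3))
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=0,     # 0 = serial loading in the main process (the default)
    pin_memory=False,  # True pins host memory for faster copies to CUDA
)
(batch,) = next(iter(loader))
print(batch.shape)  # torch.Size([16, 3])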
@@ -105,7 +107,9 @@ class Trainer(lightning.pytorch.Trainer):
         if compile is None or sys.platform == "win32":
             compile = False
-        self.automatic_batching = (
+        repeat = repeat if repeat is not None else False
+        automatic_batching = (
             automatic_batching if automatic_batching is not None else False
         )
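
With this change, ``repeat`` follows the same None-sentinel pattern already used for ``automatic_batching``: a ``None`` default distinguishes "not set by the caller" from an explicit ``False``, and is resolved to a concrete value in one place. A minimal sketch of the pattern, with hypothetical names:

def configure(repeat=None, automatic_batching=None):
    # Resolve the sentinels to concrete defaults once, in one place.
    repeat = repeat if repeat is not None else False
    automatic_batching = (
        automatic_batching if automatic_batching is not None else False
    )
    return repeat, automatic_batching

print(configure())             # (False, False): nothing set by the caller
print(configure(repeat=True))  # (True, False): only repeat set explicitly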
@@ -180,15 +184,7 @@ class Trainer(lightning.pytorch.Trainer):
     :param bool repeat: Whether to repeat the dataset data in each
         condition during training.
     :param bool automatic_batching: Whether to perform automatic batching
-        with PyTorch. If ``True``, automatic PyTorch batching
-        is performed, which consists of extracting one element at a time
-        from the dataset and collating them into a batch. This is useful
-        when the dataset is too large to fit into memory. On the other hand,
-        if ``False``, the items are retrieved from the dataset all at once
-        avoind the overhead of collating them into a batch and reducing the
-        __getitem__ calls to the dataset. This is useful when the dataset
-        fits into memory. Avoid using automatic batching when ``batch_size``
-        is large. Default is ``False``.
+        with PyTorch.
     :param bool pin_memory: Whether to use pinned memory for faster data
         transfer to GPU.
     :param int num_workers: The number of worker threads for data loading.
@@ -293,7 +289,8 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(repeat, bool)
+        if repeat is not None:
+            check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
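
Guarding the check means validation only runs when the caller actually supplied a value; the ``None`` sentinel would otherwise fail a bool type check. A sketch of the idea, where this check_consistency is a hypothetical stand-in for PINA's real helper:

def check_consistency(value, expected_type):
    # Hypothetical stand-in: reject values of the wrong type.
    if not isinstance(value, expected_type):
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(value).__name__}"
        )

def validate(repeat):
    if repeat is not None:  # None means "not set": skip the check
        check_consistency(repeat, bool)

validate(None)  # accepted: sentinel, later resolved to False
validate(True)  # accepted: a genuine bool
# validate(1) would raise TypeError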