diff --git a/pina/trainer.py b/pina/trainer.py index 8831d40..f0175e9 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -22,6 +22,7 @@ class Trainer(lightning.pytorch.Trainer): automatic_batching=None, num_workers=None, pin_memory=None, + shuffle=None, **kwargs, ): """ @@ -34,13 +35,13 @@ class Trainer(lightning.pytorch.Trainer): If ``batch_size=None`` all samples are loaded and data are not batched, defaults to None. :type batch_size: int | None - :param train_size: percentage of elements in the train dataset + :param train_size: Percentage of elements in the train dataset. :type train_size: float - :param test_size: percentage of elements in the test dataset + :param test_size: Percentage of elements in the test dataset. :type test_size: float - :param val_size: percentage of elements in the val dataset + :param val_size: Percentage of elements in the val dataset. :type val_size: float - :param predict_size: percentage of elements in the predict dataset + :param predict_size: Percentage of elements in the predict dataset. :type predict_size: float :param compile: if True model is compiled before training, default False. For Windows users compilation is always disabled. @@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer): performed. Please avoid using automatic batching when batch_size is large, default False. :type automatic_batching: bool - :param num_workers: Number of worker threads for data loading. Default 0 (serial loading) + :param num_workers: Number of worker threads for data loading. + Default 0 (serial loading). :type num_workers: int - :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False) + :param pin_memory: Whether to use pinned memory for faster data + transfer to GPU. Default False. + :type pin_memory: bool + :param shuffle: Whether to shuffle the data for training. Default False. :type pin_memory: bool :Keyword Arguments: @@ -77,6 +82,10 @@ class Trainer(lightning.pytorch.Trainer): check_consistency(pin_memory, int) else: num_workers = 0 + if shuffle is not None: + check_consistency(shuffle, bool) + else: + shuffle = False if train_size + test_size + val_size + predict_size > 1: raise ValueError( "train_size, test_size, val_size and predict_size " @@ -131,6 +140,7 @@ class Trainer(lightning.pytorch.Trainer): automatic_batching, pin_memory, num_workers, + shuffle, ) # logging @@ -166,6 +176,7 @@ class Trainer(lightning.pytorch.Trainer): automatic_batching, pin_memory, num_workers, + shuffle, ): """ This method is used here because is resampling is needed @@ -196,6 +207,7 @@ class Trainer(lightning.pytorch.Trainer): automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, + shuffle=shuffle, ) def train(self, **kwargs):