diff --git a/pina/trainer.py b/pina/trainer.py index 8e1d951..840f2a6 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -31,7 +31,8 @@ class Trainer(lightning.pytorch.Trainer): test_size=0.0, val_size=0.0, compile=None, - repeat=None, + common_batch_size=True, + separate_conditions=False, automatic_batching=None, num_workers=None, pin_memory=None, @@ -56,10 +57,12 @@ class Trainer(lightning.pytorch.Trainer): :param bool compile: If ``True``, the model is compiled before training. Default is ``False``. For Windows users, it is always disabled. Not supported for python version greater or equal than 3.14. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. For further details, see the - :class:`~pina.data.data_module.PinaDataModule` class. Default is - ``False``. + :param bool common_batch_size: If ``True``, the same batch size is used + for all conditions. If ``False``, each condition can have its own + batch size, proportional to the size of the dataset in that + condition. Default is ``True``. + :param bool separate_conditions: If ``True``, dataloaders for each + condition are iterated separately. Default is ``False``. :param bool automatic_batching: If ``True``, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset all at once. For further details, see the @@ -82,7 +85,8 @@ class Trainer(lightning.pytorch.Trainer): train_size=train_size, test_size=test_size, val_size=val_size, - repeat=repeat, + common_batch_size=common_batch_size, + seperate_conditions=separate_conditions, automatic_batching=automatic_batching, compile=compile, ) @@ -122,8 +126,6 @@ class Trainer(lightning.pytorch.Trainer): UserWarning, ) - repeat = repeat if repeat is not None else False - automatic_batching = ( automatic_batching if automatic_batching is not None else False ) @@ -139,7 +141,8 @@ class Trainer(lightning.pytorch.Trainer): test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + common_batch_size=common_batch_size, + seperate_conditions=separate_conditions, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, @@ -177,7 +180,8 @@ class Trainer(lightning.pytorch.Trainer): test_size, val_size, batch_size, - repeat, + common_batch_size, + seperate_conditions, automatic_batching, pin_memory, num_workers, @@ -196,8 +200,10 @@ class Trainer(lightning.pytorch.Trainer): :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param bool common_batch_size: Whether to use the same batch size for + all conditions. + :param bool seperate_conditions: Whether to iterate dataloaders for + each condition separately. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data @@ -227,7 +233,8 @@ class Trainer(lightning.pytorch.Trainer): test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + common_batch_size=common_batch_size, + separate_conditions=seperate_conditions, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, @@ -279,7 +286,8 @@ class Trainer(lightning.pytorch.Trainer): train_size, test_size, val_size, - repeat, + common_batch_size, + seperate_conditions, automatic_batching, compile, ): @@ -293,8 +301,10 @@ class Trainer(lightning.pytorch.Trainer): test dataset. :param float val_size: The percentage of elements to include in the validation dataset. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param bool common_batch_size: Whether to use the same batch size for + all conditions. + :param bool seperate_conditions: Whether to iterate dataloaders for + each condition separately. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool compile: If ``True``, the model is compiled before training. @@ -304,8 +314,8 @@ class Trainer(lightning.pytorch.Trainer): check_consistency(train_size, float) check_consistency(test_size, float) check_consistency(val_size, float) - if repeat is not None: - check_consistency(repeat, bool) + check_consistency(common_batch_size, bool) + check_consistency(seperate_conditions, bool) if automatic_batching is not None: check_consistency(automatic_batching, bool) if compile is not None: