integrate new datamodule in trainer

FilippoOlivo
2025-11-12 15:59:48 +01:00
parent 4d172a8821
commit 09677d3c15


@@ -31,7 +31,8 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=None,
+        common_batch_size=True,
+        separate_conditions=False,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -56,10 +57,12 @@ class Trainer(lightning.pytorch.Trainer):
         :param bool compile: If ``True``, the model is compiled before training.
             Default is ``False``. For Windows users, it is always disabled. Not
             supported for Python versions greater than or equal to 3.14.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training. For further details, see the
-            :class:`~pina.data.data_module.PinaDataModule` class. Default is
-            ``False``.
+        :param bool common_batch_size: If ``True``, the same batch size is used
+            for all conditions. If ``False``, each condition can have its own
+            batch size, proportional to the size of the dataset in that
+            condition. Default is ``True``.
+        :param bool separate_conditions: If ``True``, dataloaders for each
+            condition are iterated separately. Default is ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch batching
             is performed, otherwise the items are retrieved from the dataset
             all at once. For further details, see the
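
Taken together, ``common_batch_size`` and ``separate_conditions`` replace the old ``repeat`` flag. A minimal usage sketch follows (the solver object and the split/batch values are hypothetical; only the keyword arguments introduced in this diff are taken from it):

    from pina import Trainer

    # `solver` is a hypothetical pina solver instance built elsewhere.
    trainer = Trainer(
        solver=solver,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        batch_size=32,
        common_batch_size=False,   # per-condition batch sizes, proportional to dataset size
        separate_conditions=True,  # iterate each condition's dataloader separately
    )
    trainer.train()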
@@ -82,7 +85,8 @@ class Trainer(lightning.pytorch.Trainer):
             train_size=train_size,
             test_size=test_size,
             val_size=val_size,
-            repeat=repeat,
+            common_batch_size=common_batch_size,
+            separate_conditions=separate_conditions,
             automatic_batching=automatic_batching,
             compile=compile,
         )
@@ -122,8 +126,6 @@ class Trainer(lightning.pytorch.Trainer):
                 UserWarning,
             )
-        repeat = repeat if repeat is not None else False
         automatic_batching = (
             automatic_batching if automatic_batching is not None else False
         )
@@ -139,7 +141,8 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            repeat=repeat,
+            common_batch_size=common_batch_size,
+            separate_conditions=separate_conditions,
             automatic_batching=automatic_batching,
             pin_memory=pin_memory,
             num_workers=num_workers,
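
Since these kwargs are forwarded unchanged to the data module, a direct construction mirrors this call. A sketch, with the signature inferred from the forwarding above rather than from the data module source (``problem`` is a hypothetical pina problem instance and its position is assumed):

    from pina.data.data_module import PinaDataModule

    datamodule = PinaDataModule(
        problem,                    # hypothetical problem instance, position assumed
        train_size=0.8,
        test_size=0.1,
        val_size=0.1,
        batch_size=64,
        common_batch_size=True,     # one batch size shared by all conditions
        separate_conditions=False,  # conditions are mixed within each batch
        automatic_batching=False,
        pin_memory=False,
        num_workers=0,
    )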
@@ -177,7 +180,8 @@ class Trainer(lightning.pytorch.Trainer):
         test_size,
         val_size,
         batch_size,
-        repeat,
+        common_batch_size,
+        separate_conditions,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -196,8 +200,10 @@ class Trainer(lightning.pytorch.Trainer):
         :param float val_size: The percentage of elements to include in the
             validation dataset.
         :param int batch_size: The number of samples per batch to load.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training.
+        :param bool common_batch_size: Whether to use the same batch size for
+            all conditions.
+        :param bool separate_conditions: Whether to iterate dataloaders for
+            each condition separately.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch.
         :param bool pin_memory: Whether to use pinned memory for faster data
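
One plausible reading of the proportional behaviour behind ``common_batch_size=False``, as plain arithmetic (illustrative only; the condition names and the rounding rule are made up, and the real allocation lives inside the data module):

    # Split a batch of 32 across two conditions in proportion to
    # their dataset sizes.
    condition_sizes = {"data": 1000, "physics": 250}
    batch_size = 32
    total = sum(condition_sizes.values())
    per_condition = {
        name: max(1, round(batch_size * size / total))
        for name, size in condition_sizes.items()
    }
    # per_condition == {"data": 26, "physics": 6}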
@@ -227,7 +233,8 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            repeat=repeat,
+            common_batch_size=common_batch_size,
+            separate_conditions=separate_conditions,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -279,7 +286,8 @@ class Trainer(lightning.pytorch.Trainer):
         train_size,
         test_size,
         val_size,
-        repeat,
+        common_batch_size,
+        separate_conditions,
         automatic_batching,
         compile,
     ):
@@ -293,8 +301,10 @@ class Trainer(lightning.pytorch.Trainer):
             test dataset.
         :param float val_size: The percentage of elements to include in the
             validation dataset.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training.
+        :param bool common_batch_size: Whether to use the same batch size for
+            all conditions.
+        :param bool separate_conditions: Whether to iterate dataloaders for
+            each condition separately.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch.
         :param bool compile: If ``True``, the model is compiled before training.
@@ -304,8 +314,8 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        if repeat is not None:
-            check_consistency(repeat, bool)
+        check_consistency(common_batch_size, bool)
+        check_consistency(separate_conditions, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
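
``check_consistency`` is pina's own type guard; its body is not part of this diff, but a rough stand-in would amount to the following (an assumption for illustration, not the library's actual implementation):

    def check_consistency(obj, expected_type):
        # Raise if `obj` is not an instance of the expected type.
        if not isinstance(obj, expected_type):
            raise ValueError(
                f"expected {expected_type.__name__}, got {type(obj).__name__}"
            )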