integrate new datamodule in trainer

This commit is contained in:
FilippoOlivo
2025-11-12 15:59:48 +01:00
parent 4d172a8821
commit 09677d3c15

View File

@@ -31,7 +31,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=0.0,
val_size=0.0,
compile=None,
repeat=None,
common_batch_size=True,
separate_conditions=False,
automatic_batching=None,
num_workers=None,
pin_memory=None,
@@ -56,10 +57,12 @@ class Trainer(lightning.pytorch.Trainer):
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
``False``.
:param bool common_batch_size: If ``True``, the same batch size is used
for all conditions. If ``False``, each condition can have its own
batch size, proportional to the size of the dataset in that
condition. Default is ``True``.
:param bool separate_conditions: If ``True``, dataloaders for each
condition are iterated separately. Default is ``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset
all at once. For further details, see the
@@ -82,7 +85,8 @@ class Trainer(lightning.pytorch.Trainer):
train_size=train_size,
test_size=test_size,
val_size=val_size,
repeat=repeat,
common_batch_size=common_batch_size,
seperate_conditions=separate_conditions,
automatic_batching=automatic_batching,
compile=compile,
)
@@ -122,8 +126,6 @@ class Trainer(lightning.pytorch.Trainer):
UserWarning,
)
repeat = repeat if repeat is not None else False
automatic_batching = (
automatic_batching if automatic_batching is not None else False
)
@@ -139,7 +141,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
common_batch_size=common_batch_size,
seperate_conditions=separate_conditions,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
@@ -177,7 +180,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size,
val_size,
batch_size,
repeat,
common_batch_size,
seperate_conditions,
automatic_batching,
pin_memory,
num_workers,
@@ -196,8 +200,10 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
@@ -227,7 +233,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
common_batch_size=common_batch_size,
separate_conditions=seperate_conditions,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
@@ -279,7 +286,8 @@ class Trainer(lightning.pytorch.Trainer):
train_size,
test_size,
val_size,
repeat,
common_batch_size,
seperate_conditions,
automatic_batching,
compile,
):
@@ -293,8 +301,10 @@ class Trainer(lightning.pytorch.Trainer):
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
@@ -304,8 +314,8 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
if repeat is not None:
check_consistency(repeat, bool)
check_consistency(common_batch_size, bool)
check_consistency(seperate_conditions, bool)
if automatic_batching is not None:
check_consistency(automatic_batching, bool)
if compile is not None: