integrate new datamodule in trainer
@@ -31,7 +31,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=0.0,
val_size=0.0,
compile=None,
repeat=None,
common_batch_size=True,
separate_conditions=False,
automatic_batching=None,
num_workers=None,
pin_memory=None,
@@ -56,10 +57,12 @@ class Trainer(lightning.pytorch.Trainer):
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not
supported for Python versions greater than or equal to 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
``False``.
:param bool common_batch_size: If ``True``, the same batch size is used
for all conditions. If ``False``, each condition can have its own
batch size, proportional to the size of the dataset in that
condition. Default is ``True``.
:param bool separate_conditions: If ``True``, dataloaders for each
condition are iterated separately. Default is ``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset
all at once. For further details, see the
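Taken together, these arguments configure how the new data module splits and batches the training data. A minimal usage sketch, with argument names taken from the signature above; the ``solver`` object, the import path, the ``max_epochs`` keyword forwarded to Lightning, and the ``train()`` call are assumed from the usual PINA workflow and are not part of this diff:

from pina import Trainer  # assumed import path

trainer = Trainer(
    solver,                  # assumed: an already-built PINA solver
    train_size=0.9,
    val_size=0.1,
    test_size=0.0,
    batch_size=32,
    compile=False,
    repeat=False,
    common_batch_size=True,
    separate_conditions=False,
    automatic_batching=False,
    max_epochs=100,          # assumed: forwarded to lightning.pytorch.Trainer
)
trainer.train()              # assumed entry point of the usual PINA workflow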
@@ -82,7 +85,8 @@ class Trainer(lightning.pytorch.Trainer):
train_size=train_size,
test_size=test_size,
val_size=val_size,
repeat=repeat,
common_batch_size=common_batch_size,
seperate_conditions=separate_conditions,
automatic_batching=automatic_batching,
compile=compile,
)
@@ -122,8 +126,6 @@ class Trainer(lightning.pytorch.Trainer):
UserWarning,
)

repeat = repeat if repeat is not None else False

automatic_batching = (
automatic_batching if automatic_batching is not None else False
)
@@ -139,7 +141,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
common_batch_size=common_batch_size,
seperate_conditions=separate_conditions,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
@@ -177,7 +180,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size,
val_size,
batch_size,
repeat,
common_batch_size,
seperate_conditions,
automatic_batching,
pin_memory,
num_workers,
@@ -196,8 +200,10 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
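The ``common_batch_size`` flag documented earlier implies that, when it is ``False``, each condition receives a batch size proportional to the size of its dataset. A hypothetical helper, not the library's actual implementation, illustrating that allocation:

def proportional_batch_sizes(condition_sizes, batch_size):
    # Hypothetical illustration of the documented behaviour: each condition
    # gets a share of the batch proportional to its number of samples.
    total = sum(condition_sizes.values())
    return {
        name: max(1, round(batch_size * size / total))
        for name, size in condition_sizes.items()
    }

proportional_batch_sizes({"gamma1": 100, "gamma2": 300, "interior": 600}, batch_size=32)
# -> {'gamma1': 3, 'gamma2': 10, 'interior': 19}, i.e. batch_size split by dataset size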
@@ -227,7 +233,8 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
common_batch_size=common_batch_size,
separate_conditions=seperate_conditions,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
@@ -279,7 +286,8 @@ class Trainer(lightning.pytorch.Trainer):
train_size,
test_size,
val_size,
repeat,
common_batch_size,
seperate_conditions,
automatic_batching,
compile,
):
@@ -293,8 +301,10 @@ class Trainer(lightning.pytorch.Trainer):
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
@@ -304,8 +314,8 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
if repeat is not None:
    check_consistency(repeat, bool)
check_consistency(common_batch_size, bool)
check_consistency(seperate_conditions, bool)
if automatic_batching is not None:
    check_consistency(automatic_batching, bool)
if compile is not None:
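The ``check_consistency`` calls above validate the arguments before they reach the data module; the optional flags (``repeat``, ``automatic_batching``, ``compile``) are only checked when they are not ``None``. A stand-in sketch of the raising behaviour these calls rely on (the real helper lives in PINA's utilities; this is only an illustration):

def check_consistency(value, expected_type):
    # Stand-in for the validation helper used above: raise if the value
    # is not an instance of the expected type.
    if not isinstance(value, expected_type):
        raise ValueError(
            f"Expected {expected_type.__name__}, got {type(value).__name__}."
        )

check_consistency(0.8, float)    # passes, e.g. train_size
check_consistency(True, bool)    # passes, e.g. common_batch_size
check_consistency("0.1", float)  # raises ValueError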