integrate new datamodule in trainer
This commit is contained in:
@@ -31,7 +31,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size=0.0,
|
test_size=0.0,
|
||||||
val_size=0.0,
|
val_size=0.0,
|
||||||
compile=None,
|
compile=None,
|
||||||
repeat=None,
|
common_batch_size=True,
|
||||||
|
separate_conditions=False,
|
||||||
automatic_batching=None,
|
automatic_batching=None,
|
||||||
num_workers=None,
|
num_workers=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
@@ -56,10 +57,12 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
:param bool compile: If ``True``, the model is compiled before training.
|
:param bool compile: If ``True``, the model is compiled before training.
|
||||||
Default is ``False``. For Windows users, it is always disabled. Not
|
Default is ``False``. For Windows users, it is always disabled. Not
|
||||||
supported for python version greater or equal than 3.14.
|
supported for python version greater or equal than 3.14.
|
||||||
:param bool repeat: Whether to repeat the dataset data in each
|
:param bool common_batch_size: If ``True``, the same batch size is used
|
||||||
condition during training. For further details, see the
|
for all conditions. If ``False``, each condition can have its own
|
||||||
:class:`~pina.data.data_module.PinaDataModule` class. Default is
|
batch size, proportional to the size of the dataset in that
|
||||||
``False``.
|
condition. Default is ``True``.
|
||||||
|
:param bool separate_conditions: If ``True``, dataloaders for each
|
||||||
|
condition are iterated separately. Default is ``False``.
|
||||||
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
||||||
is performed, otherwise the items are retrieved from the dataset
|
is performed, otherwise the items are retrieved from the dataset
|
||||||
all at once. For further details, see the
|
all at once. For further details, see the
|
||||||
@@ -82,7 +85,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
train_size=train_size,
|
train_size=train_size,
|
||||||
test_size=test_size,
|
test_size=test_size,
|
||||||
val_size=val_size,
|
val_size=val_size,
|
||||||
repeat=repeat,
|
common_batch_size=common_batch_size,
|
||||||
|
seperate_conditions=separate_conditions,
|
||||||
automatic_batching=automatic_batching,
|
automatic_batching=automatic_batching,
|
||||||
compile=compile,
|
compile=compile,
|
||||||
)
|
)
|
||||||
@@ -122,8 +126,6 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
repeat = repeat if repeat is not None else False
|
|
||||||
|
|
||||||
automatic_batching = (
|
automatic_batching = (
|
||||||
automatic_batching if automatic_batching is not None else False
|
automatic_batching if automatic_batching is not None else False
|
||||||
)
|
)
|
||||||
@@ -139,7 +141,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size=test_size,
|
test_size=test_size,
|
||||||
val_size=val_size,
|
val_size=val_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
repeat=repeat,
|
common_batch_size=common_batch_size,
|
||||||
|
seperate_conditions=separate_conditions,
|
||||||
automatic_batching=automatic_batching,
|
automatic_batching=automatic_batching,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
@@ -177,7 +180,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size,
|
test_size,
|
||||||
val_size,
|
val_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
repeat,
|
common_batch_size,
|
||||||
|
seperate_conditions,
|
||||||
automatic_batching,
|
automatic_batching,
|
||||||
pin_memory,
|
pin_memory,
|
||||||
num_workers,
|
num_workers,
|
||||||
@@ -196,8 +200,10 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
:param float val_size: The percentage of elements to include in the
|
:param float val_size: The percentage of elements to include in the
|
||||||
validation dataset.
|
validation dataset.
|
||||||
:param int batch_size: The number of samples per batch to load.
|
:param int batch_size: The number of samples per batch to load.
|
||||||
:param bool repeat: Whether to repeat the dataset data in each
|
:param bool common_batch_size: Whether to use the same batch size for
|
||||||
condition during training.
|
all conditions.
|
||||||
|
:param bool seperate_conditions: Whether to iterate dataloaders for
|
||||||
|
each condition separately.
|
||||||
:param bool automatic_batching: Whether to perform automatic batching
|
:param bool automatic_batching: Whether to perform automatic batching
|
||||||
with PyTorch.
|
with PyTorch.
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
@@ -227,7 +233,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test_size=test_size,
|
test_size=test_size,
|
||||||
val_size=val_size,
|
val_size=val_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
repeat=repeat,
|
common_batch_size=common_batch_size,
|
||||||
|
separate_conditions=seperate_conditions,
|
||||||
automatic_batching=automatic_batching,
|
automatic_batching=automatic_batching,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
@@ -279,7 +286,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
train_size,
|
train_size,
|
||||||
test_size,
|
test_size,
|
||||||
val_size,
|
val_size,
|
||||||
repeat,
|
common_batch_size,
|
||||||
|
seperate_conditions,
|
||||||
automatic_batching,
|
automatic_batching,
|
||||||
compile,
|
compile,
|
||||||
):
|
):
|
||||||
@@ -293,8 +301,10 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
test dataset.
|
test dataset.
|
||||||
:param float val_size: The percentage of elements to include in the
|
:param float val_size: The percentage of elements to include in the
|
||||||
validation dataset.
|
validation dataset.
|
||||||
:param bool repeat: Whether to repeat the dataset data in each
|
:param bool common_batch_size: Whether to use the same batch size for
|
||||||
condition during training.
|
all conditions.
|
||||||
|
:param bool seperate_conditions: Whether to iterate dataloaders for
|
||||||
|
each condition separately.
|
||||||
:param bool automatic_batching: Whether to perform automatic batching
|
:param bool automatic_batching: Whether to perform automatic batching
|
||||||
with PyTorch.
|
with PyTorch.
|
||||||
:param bool compile: If ``True``, the model is compiled before training.
|
:param bool compile: If ``True``, the model is compiled before training.
|
||||||
@@ -304,8 +314,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
check_consistency(train_size, float)
|
check_consistency(train_size, float)
|
||||||
check_consistency(test_size, float)
|
check_consistency(test_size, float)
|
||||||
check_consistency(val_size, float)
|
check_consistency(val_size, float)
|
||||||
if repeat is not None:
|
check_consistency(common_batch_size, bool)
|
||||||
check_consistency(repeat, bool)
|
check_consistency(seperate_conditions, bool)
|
||||||
if automatic_batching is not None:
|
if automatic_batching is not None:
|
||||||
check_consistency(automatic_batching, bool)
|
check_consistency(automatic_batching, bool)
|
||||||
if compile is not None:
|
if compile is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user