update trainer (#466)

Dario Coscia
2025-02-26 19:29:19 +01:00
committed by Nicola Demo
parent a2c08ae211
commit c3aaf5b1a0


@@ -22,6 +22,7 @@ class Trainer(lightning.pytorch.Trainer):
automatic_batching=None,
num_workers=None,
pin_memory=None,
+shuffle=None,
**kwargs,
):
"""
@@ -34,13 +35,13 @@ class Trainer(lightning.pytorch.Trainer):
If ``batch_size=None`` all
samples are loaded and data are not batched, defaults to None.
:type batch_size: int | None
-:param train_size: percentage of elements in the train dataset
+:param train_size: Percentage of elements in the train dataset.
:type train_size: float
-:param test_size: percentage of elements in the test dataset
+:param test_size: Percentage of elements in the test dataset.
:type test_size: float
-:param val_size: percentage of elements in the val dataset
+:param val_size: Percentage of elements in the val dataset.
:type val_size: float
-:param predict_size: percentage of elements in the predict dataset
+:param predict_size: Percentage of elements in the predict dataset.
:type predict_size: float
:type predict_size: float
:param compile: if True model is compiled before training,
default False. For Windows users compilation is always disabled.
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
performed. Please avoid using automatic batching when batch_size is
large, default False.
:type automatic_batching: bool
-:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
+:param num_workers: Number of worker threads for data loading.
+    Default 0 (serial loading).
:type num_workers: int
-:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
+:param pin_memory: Whether to use pinned memory for faster data
+    transfer to GPU. Default False.
:type pin_memory: bool
+:param shuffle: Whether to shuffle the data for training. Default False.
+:type shuffle: bool
:Keyword Arguments:
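For orientation, here is a minimal sketch of a call site exercising the options documented above, including the new shuffle flag. The solver object and the split values are illustrative, not taken from this commit; only the parameter names come from the docstring.

# Hypothetical call site for the updated Trainer; `solver` is assumed
# to be a pre-built PINA solver and is not defined in this commit.
from pina import Trainer

trainer = Trainer(
    solver,
    batch_size=32,       # mini-batches of 32 samples; None loads all data
    train_size=0.7,      # the four splits must sum to at most 1
    val_size=0.2,
    test_size=0.1,
    predict_size=0.0,
    num_workers=4,       # worker threads for data loading (0 = serial)
    pin_memory=True,     # pinned memory for faster host-to-GPU transfer
    shuffle=True,        # the flag introduced by this commit
)
trainer.train()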
@@ -77,6 +82,10 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(num_workers, int)
else:
num_workers = 0
+if shuffle is not None:
+    check_consistency(shuffle, bool)
+else:
+    shuffle = False
if train_size + test_size + val_size + predict_size > 1:
raise ValueError(
"train_size, test_size, val_size and predict_size "
@@ -131,6 +140,7 @@ class Trainer(lightning.pytorch.Trainer):
automatic_batching,
pin_memory,
num_workers,
+shuffle,
)
# logging
@@ -166,6 +176,7 @@ class Trainer(lightning.pytorch.Trainer):
automatic_batching,
pin_memory,
num_workers,
+shuffle,
):
"""
This method is used here because, if resampling is needed
@@ -196,6 +207,7 @@ class Trainer(lightning.pytorch.Trainer):
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
+shuffle=shuffle,
)
def train(self, **kwargs):
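Downstream of this diff, the datamodule presumably forwards these options to PyTorch dataloaders; the datamodule internals are outside this commit, so the wiring below is a hedged sketch with a synthetic dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the training dataset managed by the datamodule.
dataset = TensorDataset(torch.randn(100, 2), torch.randn(100, 1))

# shuffle, num_workers and pin_memory map one-to-one onto the standard
# torch DataLoader arguments of the same names.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,      # the flag threaded through by this commit
    num_workers=0,     # serial loading, the documented default
    pin_memory=False,  # default; set True for faster GPU transfer
)
for inputs, targets in loader:
    ...  # a training step would consume the shuffled batches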