update trainer (#466)
This commit is contained in:
committed by
Nicola Demo
parent
a2c08ae211
commit
c3aaf5b1a0
@@ -22,6 +22,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
automatic_batching=None,
|
||||
num_workers=None,
|
||||
pin_memory=None,
|
||||
shuffle=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -34,13 +35,13 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
If ``batch_size=None`` all
|
||||
samples are loaded and data are not batched, defaults to None.
|
||||
:type batch_size: int | None
|
||||
:param train_size: percentage of elements in the train dataset
|
||||
:param train_size: Percentage of elements in the train dataset.
|
||||
:type train_size: float
|
||||
:param test_size: percentage of elements in the test dataset
|
||||
:param test_size: Percentage of elements in the test dataset.
|
||||
:type test_size: float
|
||||
:param val_size: percentage of elements in the val dataset
|
||||
:param val_size: Percentage of elements in the val dataset.
|
||||
:type val_size: float
|
||||
:param predict_size: percentage of elements in the predict dataset
|
||||
:param predict_size: Percentage of elements in the predict dataset.
|
||||
:type predict_size: float
|
||||
:param compile: if True model is compiled before training,
|
||||
default False. For Windows users compilation is always disabled.
|
||||
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
performed. Please avoid using automatic batching when batch_size is
|
||||
large, default False.
|
||||
:type automatic_batching: bool
|
||||
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
|
||||
:param num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading).
|
||||
:type num_workers: int
|
||||
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
|
||||
:param pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. Default False.
|
||||
:type pin_memory: bool
|
||||
:param shuffle: Whether to shuffle the data for training. Default False.
|
||||
:type pin_memory: bool
|
||||
|
||||
:Keyword Arguments:
|
||||
@@ -77,6 +82,10 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
check_consistency(pin_memory, int)
|
||||
else:
|
||||
num_workers = 0
|
||||
if shuffle is not None:
|
||||
check_consistency(shuffle, bool)
|
||||
else:
|
||||
shuffle = False
|
||||
if train_size + test_size + val_size + predict_size > 1:
|
||||
raise ValueError(
|
||||
"train_size, test_size, val_size and predict_size "
|
||||
@@ -131,6 +140,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
num_workers,
|
||||
shuffle,
|
||||
)
|
||||
|
||||
# logging
|
||||
@@ -166,6 +176,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
num_workers,
|
||||
shuffle,
|
||||
):
|
||||
"""
|
||||
This method is used here because is resampling is needed
|
||||
@@ -196,6 +207,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
automatic_batching=automatic_batching,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def train(self, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user