update trainer (#466)
This commit is contained in:
committed by
Nicola Demo
parent
a2c08ae211
commit
c3aaf5b1a0
@@ -22,6 +22,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
automatic_batching=None,
|
automatic_batching=None,
|
||||||
num_workers=None,
|
num_workers=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
|
shuffle=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -34,13 +35,13 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
If ``batch_size=None`` all
|
If ``batch_size=None`` all
|
||||||
samples are loaded and data are not batched, defaults to None.
|
samples are loaded and data are not batched, defaults to None.
|
||||||
:type batch_size: int | None
|
:type batch_size: int | None
|
||||||
:param train_size: percentage of elements in the train dataset
|
:param train_size: Percentage of elements in the train dataset.
|
||||||
:type train_size: float
|
:type train_size: float
|
||||||
:param test_size: percentage of elements in the test dataset
|
:param test_size: Percentage of elements in the test dataset.
|
||||||
:type test_size: float
|
:type test_size: float
|
||||||
:param val_size: percentage of elements in the val dataset
|
:param val_size: Percentage of elements in the val dataset.
|
||||||
:type val_size: float
|
:type val_size: float
|
||||||
:param predict_size: percentage of elements in the predict dataset
|
:param predict_size: Percentage of elements in the predict dataset.
|
||||||
:type predict_size: float
|
:type predict_size: float
|
||||||
:param compile: if True model is compiled before training,
|
:param compile: if True model is compiled before training,
|
||||||
default False. For Windows users compilation is always disabled.
|
default False. For Windows users compilation is always disabled.
|
||||||
@@ -49,9 +50,13 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
performed. Please avoid using automatic batching when batch_size is
|
performed. Please avoid using automatic batching when batch_size is
|
||||||
large, default False.
|
large, default False.
|
||||||
:type automatic_batching: bool
|
:type automatic_batching: bool
|
||||||
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
|
:param num_workers: Number of worker threads for data loading.
|
||||||
|
Default 0 (serial loading).
|
||||||
:type num_workers: int
|
:type num_workers: int
|
||||||
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
|
:param pin_memory: Whether to use pinned memory for faster data
|
||||||
|
transfer to GPU. Default False.
|
||||||
|
:type pin_memory: bool
|
||||||
|
:param shuffle: Whether to shuffle the data for training. Default False.
|
||||||
:type pin_memory: bool
|
:type pin_memory: bool
|
||||||
|
|
||||||
:Keyword Arguments:
|
:Keyword Arguments:
|
||||||
@@ -77,6 +82,10 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
check_consistency(pin_memory, int)
|
check_consistency(pin_memory, int)
|
||||||
else:
|
else:
|
||||||
num_workers = 0
|
num_workers = 0
|
||||||
|
if shuffle is not None:
|
||||||
|
check_consistency(shuffle, bool)
|
||||||
|
else:
|
||||||
|
shuffle = False
|
||||||
if train_size + test_size + val_size + predict_size > 1:
|
if train_size + test_size + val_size + predict_size > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"train_size, test_size, val_size and predict_size "
|
"train_size, test_size, val_size and predict_size "
|
||||||
@@ -131,6 +140,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
automatic_batching,
|
automatic_batching,
|
||||||
pin_memory,
|
pin_memory,
|
||||||
num_workers,
|
num_workers,
|
||||||
|
shuffle,
|
||||||
)
|
)
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
@@ -166,6 +176,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
automatic_batching,
|
automatic_batching,
|
||||||
pin_memory,
|
pin_memory,
|
||||||
num_workers,
|
num_workers,
|
||||||
|
shuffle,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This method is used here because is resampling is needed
|
This method is used here because is resampling is needed
|
||||||
@@ -196,6 +207,7 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
automatic_batching=automatic_batching,
|
automatic_batching=automatic_batching,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
|
shuffle=shuffle,
|
||||||
)
|
)
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user