diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index f4f8a093..f90070c1 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None, check_batch_size = max(len(self.model.device_ids), check_batch_size) _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, - batch_size=check_batch_size) + batch_size=check_batch_size, pin_memory=self.pin_memory) self.train_data = train_data self.dev_data = dev_data # If None, No validation. @@ -950,7 +950,7 @@ def _get_value_info(_dict): return strs -def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, +def _check_code(dataset, model, losser, metrics, forward_func, pin_memory, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 model_device = _get_model_device(model=model) @@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL if dev_data is not None: tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, - batch_size=batch_size, verbose=-1, use_tqdm=False) + batch_size=batch_size, verbose=-1, use_tqdm=False, pin_memory=pin_memory) evaluate_results = tester.test() _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)