From 23028534338ccc71f80b125e3b06a71dd90a49d9 Mon Sep 17 00:00:00 2001 From: ouyhlan Date: Mon, 29 Nov 2021 18:03:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DTrainer=E9=87=8Ccheck=5Fcode?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=BF=BD=E7=95=A5pin=5Fmemory=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=AF=BC=E8=87=B4=E7=9A=84=E5=86=85=E5=AD=98bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index f4f8a093..f90070c1 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None, check_batch_size = max(len(self.model.device_ids), check_batch_size) _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, - batch_size=check_batch_size) + batch_size=check_batch_size, pin_memory=self.pin_memory) self.train_data = train_data self.dev_data = dev_data # If None, No validation. @@ -950,7 +950,7 @@ def _get_value_info(_dict): return strs -def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, +def _check_code(dataset, model, losser, metrics, forward_func, pin_memory, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 model_device = _get_model_device(model=model) @@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL if dev_data is not None: tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, - batch_size=batch_size, verbose=-1, use_tqdm=False) + batch_size=batch_size, verbose=-1, use_tqdm=False, pin_memory=pin_memory) evaluate_results = tester.test() _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)