Skip to content

Commit

Permalink
修复Trainer里check_code函数忽略pin_memory参数导致的内存bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ouyhlan committed Nov 29, 2021
1 parent 9ac7d09 commit 3be86c6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion fastNLP/core/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
self.verbose = verbose
self.use_tqdm = use_tqdm
self.logger = logger
self.pin_memory = kwargs.get('pin_memory', True)
self.pin_memory = kwargs.get('pin_memory', False)

if isinstance(data, DataSet):
sampler = kwargs.get('sampler', None)
Expand Down
3 changes: 1 addition & 2 deletions fastNLP/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ def on_epoch_end(self):
except:
from .utils import _pseudo_tqdm as tqdm
import warnings
from pkg_resources import parse_version

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException, Callback
Expand Down Expand Up @@ -475,7 +474,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
if drop_last:
warnings.warn("drop_last is ignored when train_data is BatchIter.")
# concerning issue from https://github.com/pytorch/pytorch/issues/57273
self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
self.pin_memory = kwargs.get('pin_memory', False)
if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的
# device为None
if device is not None:
Expand Down

0 comments on commit 3be86c6

Please sign in to comment.