forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Function optimization] use file-lock to enable multi-process downloading (PaddlePaddle#3788) * add file_lock unittest * enable multi-processing downloading * use small weight file & more process to download * use global lock file * fix testing * modify the docstring * use self.assertGreater * change the file name & variable name * add docstring for testing method * update internal testing bert url * update multiprocess testing code * remove unused importing * add Optional import * use the origin method name * format tokenize_utils_base
- Loading branch information
Showing
7 changed files
with
190 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import unittest | ||
import shutil | ||
from tempfile import TemporaryDirectory | ||
from tests.testing_utils import slow | ||
from multiprocessing import Pool | ||
from paddlenlp.transformers import TinyBertModel, BertModel | ||
from paddlenlp.utils.env import MODEL_HOME | ||
|
||
|
||
def download_bert_model(model_name: str):
    """Fetch (and construct) a pretrained BertModel inside a worker process.

    Lives at module level because ``multiprocessing`` cannot pickle
    locally-defined functions when shipping work to pool workers.

    Args:
        model_name (str): name of the pretrained model to download
    """
    downloaded = BertModel.from_pretrained(model_name)
    # Only the download path is under test — release the model right away
    # so the worker does not hold on to the weights.
    del downloaded
|
||
|
||
class TestModeling(unittest.TestCase):
    """Test PretrainedModel behaviors once, not per Transformer model."""

    @slow
    def test_from_pretrained_with_load_as_state_np_params(self):
        """init model with `load_state_as_np` params"""
        model = TinyBertModel.from_pretrained("tinybert-4l-312d",
                                              load_state_as_np=True)
        self.assertIsNotNone(model)

    @slow
    def test_multiprocess_downloading(self):
        """test downloading with multi-process. Some errors may be triggered
        when downloading a model weight file with multiple processes, so this
        test code was born.
        `num_process_in_pool` is the number of processes in the Pool.
        And `num_jobs` is the total number of download jobs submitted.
        """
        num_process_in_pool, num_jobs = 10, 20
        small_model_path = "https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/bert/model_state.pdparams"

        from paddlenlp.transformers.model_utils import get_path_from_url
        with TemporaryDirectory() as tempdir:

            # All jobs race on the same target directory — exactly the
            # scenario the file-lock in get_path_from_url must survive.
            with Pool(num_process_in_pool) as pool:
                pool.starmap(get_path_from_url,
                             [(small_model_path, tempdir)
                              for _ in range(num_jobs)])

    # NOTE(review): restored the `@slow` marker that was left commented out —
    # this test downloads weights over the network and constructs models,
    # matching the sibling tests that are all marked @slow.
    @slow
    def test_model_from_pretrained_with_multiprocessing(self):
        """
        This test must not init too many models at once, which would occupy
        excessive CPU/GPU memory.
        `num_process_in_pool` is the number of processes in the Pool.
        And `num_jobs` is the total number of download jobs submitted.
        """
        num_process_in_pool, num_jobs = 1, 10

        # 1. remove the cached model weight dir so every job hits the
        #    download path instead of the local cache
        model_name = "__internal_testing__/bert"
        shutil.rmtree(os.path.join(MODEL_HOME, model_name), ignore_errors=True)

        # 2. download the bert model using multi-processing
        with Pool(num_process_in_pool) as pool:
            pool.starmap(download_bert_model,
                         [(model_name, ) for _ in range(num_jobs)])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import unittest | ||
import time | ||
from datetime import datetime | ||
from multiprocessing import Pool | ||
from tempfile import TemporaryDirectory, TemporaryFile | ||
from paddlenlp.utils.file_lock import FileLock | ||
|
||
|
||
def time_lock(lock_file: str) -> datetime:
    """Hold the shared lock for 1.2 seconds to exercise serialized timing.

    Args:
        lock_file (str): path of the lock file shared by all workers

    Returns:
        datetime: the moment this worker finished its critical section
    """
    lock = FileLock(lock_file)
    with lock:
        time.sleep(1.2)
        return datetime.now()
|
||
|
||
class TestFileLock(unittest.TestCase):
    """Check that FileLock serializes concurrent critical sections."""

    def test_time_lock(self):
        """Finish times of lock-guarded workers must be ~1.2s apart."""
        with TemporaryDirectory() as tempdir:
            lock_file = os.path.join(tempdir, 'download.lock')

            with Pool(4) as pool:
                datetimes = pool.map(time_lock, [lock_file for _ in range(10)])
            datetimes.sort()

            pre_time = None
            for current_time in datetimes:
                if pre_time is not None:
                    # Each worker sleeps 1.2s while holding the lock, so
                    # consecutive finish times must be at least ~1s apart.
                    # Use total_seconds(): `.seconds` truncates to whole
                    # seconds and ignores days/microseconds.
                    self.assertGreater(
                        (current_time - pre_time).total_seconds(), 1 - 1e-3)
                # BUG FIX: advance pre_time every iteration so consecutive
                # pairs are compared — the original never updated it, so
                # every element was compared against the first one only.
                pre_time = current_time