Skip to content

Commit 54c3ec3

Browse files
authored
[PIR] Fix in-batch negative recall model for neural_search (#10352)
1 parent ae355c0 commit 54c3ec3

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

slm/applications/neural_search/recall/in_batch_negative/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,15 @@ Recall@K 召回率是指预测的前 topK(top-k 是指从最后的按得分排
182182

183183
如果使用 CPU 进行训练,则需要吧`--gpus`参数去除,然后吧`device`设置成 cpu 即可,详细请参考 train_batch_neg.sh 文件的训练设置
184184

185+
如果不存在```checkpoints/inbatch```, 需要在命令行运行```mkdir -p checkpoints/inbatch```创建相关目录(如果运行脚本进行训练则不需要)。
186+
185187
然后运行下面的命令使用 GPU 训练,得到语义索引模型:
186188

187189
```
188-
root_path=inbatch
189190
python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
190191
train_batch_neg.py \
191192
--device gpu \
192-
--save_dir ./checkpoints/${root_path} \
193+
--save_dir ./checkpoints/inbatch \
193194
--batch_size 64 \
194195
--learning_rate 5E-5 \
195196
--epochs 3 \
@@ -464,7 +465,7 @@ python deploy/python/predict.py \
464465
也可以运行下面的 bash 脚本:
465466

466467
```
467-
sh deploy.sh
468+
sh deploy/python/deploy.sh
468469
```
469470
最终输出的是256维度的特征向量和句子对的预测概率:
470471

slm/applications/neural_search/recall/in_batch_negative/deploy/python/deploy.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
python predict.py --model_dir=../../output
15+
python ./deploy/python/predict.py --model_dir=./output \
16+
--model_name_or_path rocketqa-zh-base-query-encoder

slm/applications/neural_search/recall/in_batch_negative/deploy/python/predict.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from paddlenlp.data import Pad, Tuple
2323
from paddlenlp.transformers import AutoTokenizer
2424
from paddlenlp.utils.log import logger
25+
from paddlenlp.utils.env import PADDLE_INFERENCE_MODEL_SUFFIX, PADDLE_INFERENCE_WEIGHTS_SUFFIX
2526

2627
sys.path.append(".")
2728

@@ -87,8 +88,8 @@ def __init__(
8788
self.max_seq_length = max_seq_length
8889
self.batch_size = batch_size
8990

90-
model_file = model_dir + "/inference.pdmodel"
91-
params_file = model_dir + "/inference.pdiparams"
91+
model_file = model_dir + f"/inference{PADDLE_INFERENCE_MODEL_SUFFIX}"
92+
params_file = model_dir + f"/inference{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
9293
if not os.path.exists(model_file):
9394
raise ValueError("not find model file path {}".format(model_file))
9495
if not os.path.exists(params_file):
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
root_dir="./checkpoints/inbatch"
16+
17+
if [ ! -d "$root_dir" ]; then
18+
mkdir -p "$root_dir"
19+
echo "Created directory: $root_dir"
20+
else
21+
echo "Directory already exists: $root_dir"
22+
fi
23+
24+
python -u -m paddle.distributed.launch --gpus "0" \
25+
train_batch_neg.py \
26+
--device gpu \
27+
--save_dir ${root_dir} \
28+
--batch_size 64 \
29+
--learning_rate 5E-5 \
30+
--epochs 3 \
31+
--output_emb_size 256 \
32+
--model_name_or_path rocketqa-zh-base-query-encoder \
33+
--save_steps 10 \
34+
--max_seq_length 64 \
35+
--margin 0.2 \
36+
--train_set_file recall/train.csv \
37+
--recall_result_dir "recall_result_dir" \
38+
--recall_result_file "recall_result.txt" \
39+
--hnsw_m 100 \
40+
--hnsw_ef 100 \
41+
--recall_num 50 \
42+
--similar_text_pair_file "recall/dev.csv" \
43+
--corpus_file "recall/corpus.csv"

0 commit comments

Comments
 (0)