Skip to content

Commit 031ff47

Browse files
authored
[PIR] Fix machine_reading_comprehension SQuAD (#10445)
* [PIR] fix machine_reading_comprehension squad * Update README.md
1 parent d24ec24 commit 031ff47

File tree

2 files changed

+96
-1
lines changed

2 files changed

+96
-1
lines changed

slm/examples/machine_reading_comprehension/SQuAD/README.md

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,22 @@ python -m paddle.distributed.launch --gpus "0" run_squad.py \
164164

165165
在 Fine-tune 完成后,我们可以使用如下方式导出希望用来预测的模型:
166166

167+
默认模型:
167168
```shell
168169
python -u ./export_model.py \
169170
--model_type bert \
170171
--model_path bert-base-uncased \
171172
--output_path ./infer_model/model
172173
```
173174

175+
微调模型:
176+
```shell
177+
python -u ./export_model.py \
178+
--model_type bert \
179+
--model_path tmp/squad/model_5000 \
180+
--output_path ./infer_model/model
181+
```
182+
174183
其中参数释义如下:
175184
- `model_type` 指示了模型类型,使用 BERT 模型时设置为 bert 即可。
176185
- `model_path` 表示训练模型的保存路径,与训练时的`output_dir`一致。
@@ -192,4 +201,83 @@ python -u deploy/python/predict.py \
192201
- `batch_size` 表示每个预测批次的样本数目。
193202
- `max_seq_length` 表示最大句子长度,超过该长度将被截断,和训练时一致。
194203

204+
运行结果示例:
205+
```
206+
{
207+
"exact": 37.74109323675567,
208+
"f1": 42.348199704946815,
209+
"total": 11873,
210+
"HasAns_exact": 75.59041835357625,
211+
"HasAns_f1": 84.81784330243481,
212+
"HasAns_total": 5928,
213+
"NoAns_exact": 0.0,
214+
"NoAns_f1": 0.0,
215+
"NoAns_total": 5945,
216+
"best_exact": 50.11370336056599,
217+
"best_exact_thresh": 0.0,
218+
"best_f1": 50.11370336056599,
219+
"best_f1_thresh": 0.0
220+
}
221+
```
222+
195223
以上命令将在 SQuAD v1.1的验证集上进行预测。此外,同训练时一样,用户可以通过命令行传入`--version_2_with_negative`控制所需要的 SQuAD 数据集版本。
224+
225+
### 其他问题
226+
#### Q1: 适配 python 3.8的 datasets 3.1.0无法支持当前任务
227+
如果运行时出现如下问题:
228+
229+
> File "/home/aistudio/.cache/huggingface/modules/datasets_modules/datasets/squad_v2/dca5ba0e483a42ca20ec41a13e9fb630541d6fcb0ba646da3e8ff9a1f21fcb81/squad_v2.py", line 19, in <module>
230+
> from datasets.tasks import QuestionAnsweringExtractive
231+
> ModuleNotFoundError: No module named 'datasets.tasks'
232+
233+
那么需要对 datasets 进行版本更换。运行:
234+
```shell
235+
pip install -U "datasets>=2.14.6,<3.0.0"
236+
```
237+
安装 ```datasets-2.21.0```等版本可以正常运行。
238+
239+
240+
#### Q2: 无法通过运行命令连接 huggingface 获取 SQuAD 数据集
241+
1. 手动从[数据集官网](https://rajpurkar.github.io/SQuAD-explorer/)下载 training/dev set 并放在当前目录。
242+
243+
2.```run_squad.py```中的
244+
```python
245+
if args.version_2_with_negative:
246+
train_examples = load_dataset("squad_v2", split="train", trust_remote_code=True)
247+
dev_examples = load_dataset("squad_v2", split="validation", trust_remote_code=True)
248+
else:
249+
train_examples = load_dataset("squad", split="train", trust_remote_code=True)
250+
dev_examples = load_dataset("squad", split="validation", trust_remote_code=True)
251+
```
252+
替换为
253+
```python
254+
datasets = load_dataset(
255+
"squad_v2",
256+
data_files={
257+
"train": "train-v2.0.json",
258+
"validation": "dev-v2.0.json"
259+
}
260+
)
261+
train_examples = datasets["train"]
262+
dev_examples = datasets["validation"]
263+
```
264+
265+
3.```deploy/python/predict.py```中的
266+
```python
267+
if args.version_2_with_negative:
268+
raw_dataset = load_dataset("squad_v2", split="validation")
269+
else:
270+
raw_dataset = load_dataset("squad", split="validation")
271+
```
272+
替换为
273+
```python
274+
datasets = load_dataset(
275+
"squad_v2",
276+
data_files={
277+
"train": "train-v2.0.json",
278+
"validation": "dev-v2.0.json"
279+
}
280+
)
281+
raw_dataset = datasets["validation"]
282+
```
283+
正常运行命令即可。

slm/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121

2222
from paddlenlp.data import Dict, Pad
2323
from paddlenlp.metrics.squad import compute_prediction, squad_evaluate
24+
from paddlenlp.utils.env import (
25+
PADDLE_INFERENCE_MODEL_SUFFIX,
26+
PADDLE_INFERENCE_WEIGHTS_SUFFIX,
27+
)
2428

2529
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)))
2630
from args import parse_args # noqa: E402
@@ -35,7 +39,10 @@ def __init__(self, predictor, input_handles, output_handles):
3539

3640
@classmethod
3741
def create_predictor(cls, args):
38-
config = paddle.inference.Config(args.model_name_or_path + ".pdmodel", args.model_name_or_path + ".pdiparams")
42+
config = paddle.inference.Config(
43+
args.model_name_or_path + f"{PADDLE_INFERENCE_MODEL_SUFFIX}",
44+
args.model_name_or_path + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}",
45+
)
3946
if args.device == "gpu":
4047
# set GPU configs accordingly
4148
config.enable_use_gpu(100, 0)

0 commit comments

Comments
 (0)