【PaddleNLP No.5】fix ernie_matching PIR #10399

Merged 1 commit on Apr 16, 2025
@@ -125,6 +125,13 @@ ernie_matching/

```

Download the dataset and unzip it into the current directory:
```shell
wget https://bj.bcebos.com/v1/paddlenlp/data/literature_search_data.zip
unzip literature_search_data.zip
```
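The wget + unzip step above can also be done in pure Python, which is handy on systems without those tools. A minimal stdlib-only sketch (the URL is the dataset link from this README; the helper name is mine):

```python
# Pure-Python equivalent of the wget + unzip step above.
import os
import urllib.request
import zipfile

def download_and_extract(url, dest_dir="."):
    """Download a zip archive (skipping if already present) and extract it."""
    archive = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive):  # avoid re-downloading the archive
        urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)

# download_and_extract("https://bj.bcebos.com/v1/paddlenlp/data/literature_search_data.zip")
```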


<a name="模型训练"></a>

## 5. Model Training
@@ -301,7 +308,7 @@ python export_to_serving.py \

Parameter descriptions
* `dirname`: storage path of the model files to convert; both the Program structure file and the parameter files are saved in this directory.
* `model_filename`: name of the file storing the Inference Program structure of the model to convert. If set to None, `__model__` is used as the default file name
* `model_filename`: name of the file storing the Inference Program structure of the model to convert. If set to None, `__model__` is used as the default file name. If PIR is enabled, this may be a ```.json``` file, so check which file was actually exported
* `params_filename`: name of the file storing all parameters of the model to convert. It needs to be specified if and only if all model parameters are saved in a single binary file. If the parameters are stored in separate files, set it to None
* `server_path`: storage path for the converted model files and configuration files. Defaults to serving_server
* `client_path`: storage path for the converted client configuration file. Defaults to serving_client
@@ -23,6 +23,10 @@
from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.utils.env import (
PADDLE_INFERENCE_MODEL_SUFFIX,
PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)
from paddlenlp.utils.log import logger

sys.path.append(".")
@@ -86,8 +90,8 @@ def __init__(
self.max_seq_length = max_seq_length
self.batch_size = batch_size

model_file = model_dir + "/inference.predict.pdmodel"
params_file = model_dir + "/inference.predict.pdiparams"
model_file = model_dir + f"/inference.predict{PADDLE_INFERENCE_MODEL_SUFFIX}"
params_file = model_dir + f"/inference.predict{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
if not os.path.exists(model_file):
raise ValueError("not find model file path {}".format(model_file))
if not os.path.exists(params_file):
Expand Down
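The suffix constants imported above let the same predictor load either export format. A minimal sketch of the idea, assuming (as stand-ins for `PADDLE_INFERENCE_MODEL_SUFFIX` / `PADDLE_INFERENCE_WEIGHTS_SUFFIX`) that the Program suffix is `.json` when PIR is enabled and `.pdmodel` otherwise, while the weights suffix stays `.pdiparams`:

```python
import os

def inference_file_names(model_dir, pir_enabled):
    # Assumed stand-ins for the paddlenlp.utils.env suffix constants:
    # under PIR the exported Program is a .json file, otherwise .pdmodel.
    model_suffix = ".json" if pir_enabled else ".pdmodel"
    model_file = os.path.join(model_dir, "inference.predict" + model_suffix)
    params_file = os.path.join(model_dir, "inference.predict.pdiparams")
    return model_file, params_file
```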
@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# PIR disable
python export_to_serving.py \
--dirname "output" \
--model_filename "inference.predict.pdmodel" \
--params_filename "inference.predict.pdiparams" \
--server_path "serving_server" \
--client_path "serving_client" \
--fetch_alias_names "predict"

# PIR enable
# python export_to_serving.py \
# --dirname "output" \
# --model_filename "inference.predict.json" \
# --params_filename "inference.predict.pdiparams" \
# --server_path "serving_server" \
# --client_path "serving_client" \
# --fetch_alias_names "predict"
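Rather than maintaining two hard-coded variants of the script, the right `--model_filename` can be detected from what was actually exported. A hedged sketch (the helper is mine, not part of the repo; it only checks which Program file exists under `dirname`):

```python
import os

def build_export_cmd(dirname="output"):
    """Pick the Program-structure file that exists (PIR export -> .json,
    legacy export -> .pdmodel) and build the export_to_serving.py command."""
    for suffix in (".json", ".pdmodel"):
        model_filename = "inference.predict" + suffix
        if os.path.exists(os.path.join(dirname, model_filename)):
            break  # falls back to .pdmodel if neither file is present
    return [
        "python", "export_to_serving.py",
        "--dirname", dirname,
        "--model_filename", model_filename,
        "--params_filename", "inference.predict.pdiparams",
        "--server_path", "serving_server",
        "--client_path", "serving_client",
        "--fetch_alias_names", "predict",
    ]

# import subprocess; subprocess.run(build_export_cmd(), check=True)
```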