Commit 1aadf1b

[PIR] Fix text_classification (#10497)
1 parent 95901cb commit 1aadf1b

7 files changed: +37 −13 lines changed

slm/applications/text_classification/hierarchical/retrieval_based/deploy/python/predict.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 from paddlenlp.data import Pad, Tuple
 from paddlenlp.transformers import AutoTokenizer
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 sys.path.append(".")
 
@@ -114,8 +118,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):
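
The same substitution is repeated in the remaining deployment and feature-extraction scripts below. As a standalone illustration, here is a minimal sketch of the new path construction, assuming (based on the README change in this commit) that `PADDLE_INFERENCE_MODEL_SUFFIX` resolves to `.json` when PIR is enabled and to `.pdmodel` otherwise, while `PADDLE_INFERENCE_WEIGHTS_SUFFIX` stays `.pdiparams`; the `model_dir` value is hypothetical.

```python
import os

from paddlenlp.utils.env import (
    PADDLE_INFERENCE_MODEL_SUFFIX,  # ".json" under PIR, ".pdmodel" otherwise (assumption)
    PADDLE_INFERENCE_WEIGHTS_SUFFIX,  # ".pdiparams" (assumption)
)

model_dir = "./output"  # hypothetical export directory

# Build the inference file paths from the environment-dependent suffixes
# instead of hard-coding ".pdmodel"/".pdiparams".
model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"

if not os.path.exists(model_file):
    raise ValueError("not find model file path {}".format(model_file))
if not os.path.exists(params_file):
    raise ValueError("not find params file path {}".format(params_file))
```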

slm/applications/text_classification/hierarchical/retrieval_based/utils/feature_extract.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 import paddlenlp as ppnlp
 from paddlenlp.data import Pad, Tuple
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 # fmt: off
 parser = argparse.ArgumentParser()
@@ -82,8 +86,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):

slm/applications/text_classification/multi_class/retrieval_based/deploy/python/predict.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 from paddlenlp.data import Pad, Tuple
 from paddlenlp.transformers import AutoTokenizer
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 sys.path.append(".")
 
@@ -114,8 +118,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):

slm/applications/text_classification/multi_class/retrieval_based/utils/feature_extract.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 import paddlenlp as ppnlp
 from paddlenlp.data import Pad, Tuple
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 # fmt: off
 parser = argparse.ArgumentParser()
@@ -83,8 +87,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):

slm/applications/text_classification/multi_label/README.md

Lines changed: 1 addition & 1 deletion

@@ -355,7 +355,7 @@ python export_model.py --params_path ./checkpoint/ --output_path ./export --mult
 export/
 ├── float32.pdiparams
 ├── float32.pdiparams.info
-└── float32.pdmodel
+└── float32.json(PIR enabled)/float32.pdmodel(PIR disabled)
 ```
 After the model is exported, it can be used for deployment. The project provides an ONNXRuntime-based [offline deployment solution](./deploy/predictor/README.md) and a Paddle Serving-based [online serving deployment solution](./deploy/predictor/README.md).
 
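
For deployment code that has to consume this export directory, here is a hedged sketch of how one might locate the static-graph file regardless of whether PIR was enabled; the file names follow the tree above, while the fallback order and `export_dir` value are assumptions.

```python
import os

export_dir = "./export"  # hypothetical --output_path used with export_model.py

# PIR-enabled exports produce "float32.json", legacy exports "float32.pdmodel";
# try both and keep whichever exists.
model_file = None
for suffix in (".json", ".pdmodel"):
    candidate = os.path.join(export_dir, "float32" + suffix)
    if os.path.exists(candidate):
        model_file = candidate
        break
if model_file is None:
    raise FileNotFoundError("no exported model found under {}".format(export_dir))

params_file = os.path.join(export_dir, "float32.pdiparams")
```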

slm/applications/text_classification/multi_label/retrieval_based/deploy/python/predict.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 from paddlenlp.data import Pad, Tuple
 from paddlenlp.transformers import AutoTokenizer
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 sys.path.append(".")
 
@@ -114,8 +118,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):

slm/applications/text_classification/multi_label/retrieval_based/utils/feature_extract.py

Lines changed: 6 additions & 2 deletions

@@ -22,6 +22,10 @@
 
 import paddlenlp as ppnlp
 from paddlenlp.data import Pad, Tuple
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
 
 # fmt: off
 parser = argparse.ArgumentParser()
@@ -84,8 +88,8 @@ def __init__(
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
 
-        model_file = model_dir + "/inference.get_pooled_embedding.pdmodel"
-        params_file = model_dir + "/inference.get_pooled_embedding.pdiparams"
+        model_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_MODEL_SUFFIX}"
+        params_file = model_dir + f"/inference.get_pooled_embedding{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
         if not os.path.exists(model_file):
             raise ValueError("not find model file path {}".format(model_file))
         if not os.path.exists(params_file):
