diff --git a/slm/model_zoo/gpt-3/ppfleetx/core/engine/inference_engine.py b/slm/model_zoo/gpt-3/ppfleetx/core/engine/inference_engine.py
index 58e323d910cd..c9014bb595f5 100644
--- a/slm/model_zoo/gpt-3/ppfleetx/core/engine/inference_engine.py
+++ b/slm/model_zoo/gpt-3/ppfleetx/core/engine/inference_engine.py
@@ -18,6 +18,12 @@
 import numpy as np
 import paddle
 import paddle.distributed.fleet as fleet
+
+from paddlenlp.utils.env import (
+    PADDLE_INFERENCE_MODEL_SUFFIX,
+    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
+
 try:
     from ppfleetx_ops import topp_sampling
 except Exception as e:
@@ -153,9 +159,9 @@ def _check_model(self):
         model_files = []
         param_files = []
         for fname in os.listdir(rank_path):
-            if os.path.splitext(fname)[1] == ".pdmodel":
+            if os.path.splitext(fname)[1] == f"{PADDLE_INFERENCE_MODEL_SUFFIX}":
                 model_files.append(fname)
-            if os.path.splitext(fname)[1] == ".pdiparams":
+            if os.path.splitext(fname)[1] == f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}":
                 param_files.append(fname)
 
         def _check_and_get_file(files, tag):
@@ -184,8 +190,8 @@ def _generate_comm_init_config(self, rank, nranks):
 
     def _init_predictor(self):
         if self.auto:
-            self.model_file = os.path.join(self.model_dir, "auto_dist{}.pdmodel".format(self.rank))
-            self.param_file = os.path.join(self.model_dir, "auto_dist{}.pdiparams".format(self.rank))
+            self.model_file = os.path.join(self.model_dir, f"auto_dist{self.rank}{PADDLE_INFERENCE_MODEL_SUFFIX}")
+            self.param_file = os.path.join(self.model_dir, f"auto_dist{self.rank}{PADDLE_INFERENCE_WEIGHTS_SUFFIX}")
             config = paddle.inference.Config(self.model_file, self.param_file)
             config.enable_memory_optim()