Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

手动调用mode()方法和直接调用model.generate()方法,输出的结果十分不同 #1328

Open
1 of 2 tasks
rooikeee opened this issue Oct 17, 2024 · 0 comments
Open
1 of 2 tasks

Comments

@rooikeee
Copy link

System Info / 系統信息

CUDA: 12.6
Transformer: 4.41.0
python: 3.12.3
model: ChatGLM3-6b-8k

Who can help? / 谁可以帮助到您?

No response

Information / 问题信息

  • The official example scripts / 官方的示例脚本
  • My own modified scripts / 我自己修改的脚本和任务

Reproduction / 复现过程

我使用了两种generate方法,第一种方法如下:

output = model.generate(
        **input, 
        max_num_tokens=1,
        num_beams =1,
        do_sample=False,
        temperature=1.0
)[0]

得到的输出如下:
微信图片_20241018000231

第二种方法如下:

            with torch.no_grad():
                # prefill
                output = model(
                    input_ids=input.input_ids,
                    past_key_values=None,
                    use_cache=True,
                )
                past_key_values = output.past_key_values
                pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
                generated_content = [pred_token_idx.item()]
                
                # decode
                for _ in range(max_gen - 1):
                    outputs = model(
                        input_ids=pred_token_idx,
                        past_key_values=past_key_values,
                        use_cache=True,
                    )

                    past_key_values = outputs.past_key_values
                    pred_token_idx = (
                        outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
                    )
                    generated_content += [pred_token_idx.item()]
                    if pred_token_idx.item() == tokenizer.eos_token_id:
                        break

得到的输出为:
微信图片_20241018000515

Expected behavior / 期待表现

两种方法的输入都是相同的,按理说得到的输出相差应该不会太大。请问大佬我忽略了那些因素?正确的调用方法应该是什么?谢谢大佬解答。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant