[Bug] Infinite loop while using xgrammar with specific grammar #127

Open
roG0d opened this issue Dec 12, 2024 · 1 comment

roG0d commented Dec 12, 2024

Env

xgrammar == 0.1.6
CUDA == 12.4
torch == 2.4.0
transformers == 4.46.1

Issue

Hitting an infinite loop when using the specific grammar provided in the code below. (WIP: trying to debug and understand the error.)

Context

We (@antferdom) are researching different grammar-based decoding techniques for code generation. In this specific setting we're trying to generate a DSL called UVL. A recurring problem is that UVL is indentation-sensitive and therefore context-dependent; if this can be handled, I believe it could unlock Python codegen as well.

Errors

No error is triggered during generation; instead, execution runs indefinitely. When the process is stopped with SIGINT (Ctrl+C), the traceback is:

File "/root/ebnf_xgrammar.py", line 110, in <module>
    generated_ids = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 3223, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/logits_process.py", line 104, in __call__
    scores = processor(input_ids, scores)
  File "/usr/local/lib/python3.10/dist-packages/xgrammar/contrib/hf.py", line 90, in __call__
    self.matchers[i].fill_next_token_bitmask(self.token_bitmask, i)
  File "/usr/local/lib/python3.10/dist-packages/xgrammar/matcher.py", line 211, in fill_next_token_bitmask
    self._handle.fill_next_token_bitmask(bitmask.data_ptr(), list(bitmask.shape), index)

Code snippet for reproducibility

# from: https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
import xgrammar as xgr

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, set_seed
import random

r1 = 2  # random.randint(0, 10**5)
set_seed(r1)
print(f"seed: {r1}")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "meta-llama/Llama-3.2-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map=device
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

messages = [
    {"role": "system", "content": """
You are a helpful AI assistant for creating syntactically correct expressions similar to this example:
features
INDENT
    SmartWatch
    INDENT
        mandatory
        INDENT
            Functionalities
            INDENT
                mandatory
                INDENT
                    FitnessMonitor
                    SleepTracker
                    VibrateAlert
                    INDENT
                        mandatory
                        INDENT
                            Call
                            Notification
                        DEDENT
                    DEDENT
                DEDENT
            DEDENT
            Sensors
            INDENT
                mandatory
                INDENT
                    Pedometer
                    Accelerometer
                DEDENT
                optional
                INDENT
                    HeartRateSensor
                DEDENT
            DEDENT
            Connectivity
            INDENT
                mandatory
                INDENT
                    BT40
                DEDENT
            DEDENT
        DEDENT
    DEDENT
DEDENT
     """},

    {"role": "user", "content": """       
    Create a new one representing a computer
     """},
]
texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer(texts, return_tensors="pt").to(model.device)

# Build the grammar compiler from the HF tokenizer's vocabulary
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)

uvl_base = r"""
root ::= "\n features \n\t" feature ws
feature ::= reference "\n"( ("\t")+ group)? 
group    ::= ( "or " | "alternative " | "optional " | "mandatory ") ("\n" ("\t")+  feature)+ 
reference ::= [a-zA-Z_] [a-zA-Z_0-9]*
ws ::= [ \t\n]
"""

compiled_grammar = grammar_compiler.compile_grammar(uvl_base)

# Constrain HF generation with the compiled grammar via a logits processor
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

import datetime
init_time = datetime.datetime.now()

generated_ids = model.generate(
    **model_inputs, max_new_tokens=500, logits_processor=[xgr_logits_processor]
)
generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :] 
print(tokenizer.decode(generated_ids, skip_special_tokens=False))

print(f"generation elapsed time :{datetime.datetime.now()-init_time}")

Ubospica (Collaborator) commented Dec 25, 2024

Hi @roG0d, thanks for the bug report, and sorry for the late response; I just got back from traveling.

After digging into your script, I think this is mainly because there is too much nondeterminism in the grammar (meaning that, for the same input string, there can be multiple interpretations according to the grammar). Too much nondeterminism means multiple (possibly exponentially many) parsing stacks, which increases overhead at runtime. (Check out the "Why is it hard to accelerate general CFGs?" section in our blog.)

Specifically in your grammar,

feature ::= reference "\n" ( ("\t")+ group )?
group ::= ( "or " | "alternative " | "optional " | "mandatory ") ("\n" ("\t")+  feature)+ 
reference ::= [a-zA-Z_] [a-zA-Z_0-9]*

group and feature overlap because or, alternative, etc. can be interpreted either as a reference or as a group keyword. So for this input

feature1
	or 
		feature2

or can be treated either as a feature or as the start of a group. As the nesting level increases, the number of stacks can grow significantly.

For now, you can try rewriting your grammar to reduce such nondeterminism; one possible direction is sketched below.
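
For example, one untested sketch (the rule names and the depth cap are illustrative, not a drop-in fix) is to encode each indentation depth as its own rule, so the number of leading tabs uniquely determines which rule applies. This caps the nesting depth (three feature levels here), but it keeps the set of live parsing stacks small, and the keyword/identifier overlap stops mattering because the tab count has already decided whether a kw or a reference is expected:

root         ::= "\nfeatures\n" top_feature ws
top_feature  ::= "\t" reference "\n" top_group?
top_group    ::= "\t\t" kw "\n" mid_feature+
mid_feature  ::= "\t\t\t" reference "\n" mid_group?
mid_group    ::= "\t\t\t\t" kw "\n" leaf_feature+
leaf_feature ::= "\t\t\t\t\t" reference "\n"
kw           ::= "or" | "alternative" | "optional" | "mandatory"
reference    ::= [a-zA-Z_] [a-zA-Z_0-9]*
ws           ::= [ \t\n]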

We are also looking into ways to alleviate this issue through automatic optimization. Please stay tuned!
