fix-issue-415-integer-type-error #420

Open · wants to merge 1 commit into base: main
104 changes: 53 additions & 51 deletions src/llm_vm/guided_completion.py
@@ -3,9 +3,9 @@
import torch
from lark import Lark, Transformer, v_args
from lark.indenter import PythonIndenter
-from transformers import (AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor, AutoTokenizer)
+from transformers import AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor, AutoTokenizer
import re
-from abc import ABC,abstractmethod
+from abc import ABC, abstractmethod

model = models.transformers("gpt2")

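For context, `models.transformers("gpt2")` above creates an outlines wrapper around a Hugging Face model. A minimal sketch of how that wrapper is typically driven, assuming the installed outlines version exposes `generate.choice` (prompt and choices are illustrative):

```python
# Sketch only: constrain gpt2 to one of two strings via outlines.
from outlines import generate, models

model = models.transformers("gpt2")           # wraps model + tokenizer
pick = generate.choice(model, ["yes", "no"])  # build a constrained generator
print(pick("Is Python dynamically typed? Answer yes or no: "))
```
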
@@ -37,12 +37,12 @@ def create(regex, type, choices, grammar_type, *, default=None):
class GenerativeCompletion(Completion):
def __init__(self, generator, *generator_args):
"""
Parameters:
-----------

generator (Callable[[Transformers, ...generator_args], None]): Generator function invoked when the completion runs
*generator_args (Any): Generator arguments (without model)

"""
self.generator = generator
self.generator_args = generator_args
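
`GenerativeCompletion` stores the generator callable and its non-model arguments so the model can be bound later, when the completion actually runs. A self-contained sketch of that deferred-binding pattern (class and method names here are hypothetical, not the real interface):

```python
from typing import Any, Callable

class DeferredCompletion:
    # Capture a generator callable plus its non-model arguments.
    def __init__(self, generator: Callable[..., Any], *generator_args: Any):
        self.generator = generator            # e.g. outlines' generate.format
        self.generator_args = generator_args  # everything except the model

    def run(self, model: Any, prompt: str) -> Any:
        # Bind the model only at call time, then generate from the prompt.
        bound = self.generator(model, *self.generator_args)
        return bound(prompt)
```
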
@@ -62,7 +62,7 @@ def choices_completion(choices):
def type_completion(type_name):
if type_name not in ["float", "integer"]:
raise Exception("type must be float or integer")
-return GenerativeCompletion(getattr(generate, type_name))
+return GenerativeCompletion(generate.format, type_name)

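This one-line change is the substantive fix for issue #415: the old `getattr(generate, type_name)` lookup resolved to attributes such as `generate.integer`, which the current outlines release evidently no longer provides, so numeric types are now routed through `generate.format`. A hedged sketch of the new call path; note that `generate.format` takes a Python type rather than a string, so the name-to-type mapping below is illustrative:

```python
# Assumes an outlines version that exposes generate.format.
from outlines import generate, models

model = models.transformers("gpt2")
TYPE_MAP = {"integer": int, "float": float}  # hypothetical mapping

generator = generate.format(model, TYPE_MAP["integer"])
print(generator("2 + 2 = "))  # output constrained to an integer
```
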
@staticmethod
def response_completion():
@@ -78,7 +78,7 @@ def __init__(self, model_uri, tokenizer, grammar_type='python'):
# Load model from HuggingFace
self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier)
print(f"{self.model_identifier} model is ready for use on {self.model.device}", flush=True)

self.constraint = GrammarConstraint.create(grammar_type, model_uri, tokenizer)

# Initialize dict to store terminal symbols and their corresponding token mappings
@@ -87,7 +87,7 @@ def __init__(self, model_uri, tokenizer, grammar_type='python'):
@property
def eos_token_id(self):
return self.model.config.eos_token_id

def complete(self, prompt, **model_kwargs):
# Check if the specified grammar type is supported

@@ -115,12 +115,18 @@ def complete(self, prompt, **model_kwargs):
model_kwargs["logits_processor"] = logits_processor

# Generate completion using the model
-res = self.model.generate(**model_kwargs, input_ids=token_ids['input_ids'], attention_mask=token_ids['attention_mask'], eos_token_id=self.eos_token_id, pad_token_id=self.eos_token_id)
+res = self.model.generate(
+    **model_kwargs,
+    input_ids=token_ids['input_ids'],
+    attention_mask=token_ids['attention_mask'],
+    eos_token_id=self.eos_token_id,
+    pad_token_id=self.eos_token_id,
+)

# Decode completion and return the generated text, including input tokens
res_text = self.tokenizer.batch_decode(res.sequences, skip_special_tokens=True)[0]
return res_text.strip()


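For context, this is the standard transformers hook the call above uses: every processor in a `LogitsProcessorList` is invoked at each decoding step. A runnable, self-contained sketch with a pass-through processor (checkpoint and prompt are placeholders; the diff's `res.sequences` also implies `return_dict_in_generate=True` arrives via `model_kwargs`, whereas this sketch uses the plain tensor return):

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList)

class PassThrough(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores  # a grammar processor would rescore tokens here

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("def add(a, b):", return_tensors="pt")

out = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    logits_processor=LogitsProcessorList([PassThrough()]),
    max_new_tokens=16,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
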
class GrammarConstraint(ABC):
# Initialize the GrammarConstraint class with the model URI and tokenizer
@@ -153,8 +159,7 @@ def create(grammar_type, model_uri, tokenizer):
return grammars[grammar_type](model_uri, tokenizer)

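`create` dispatches on the grammar-type string through a registry dict. A miniature of the same factory pattern (stubbed classes; the error type is an assumption, the original may handle unknown types differently):

```python
class ConstraintBase:
    def __init__(self, model_uri, tokenizer):
        self.model_uri, self.tokenizer = model_uri, tokenizer

class PythonSketch(ConstraintBase): pass
class JSONSketch(ConstraintBase): pass

def create(grammar_type, model_uri, tokenizer):
    grammars = {"python": PythonSketch, "json": JSONSketch}
    if grammar_type not in grammars:
        raise ValueError(f"unsupported grammar type: {grammar_type!r}")
    return grammars[grammar_type](model_uri, tokenizer)

constraint = create("python", "gpt2", tokenizer=None)  # toy usage
```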

-class PythonConstraint(GrammarConstraint):
-
+class PythonConstraint(GrammarConstraint):
# Construct a filter set for valid tokens based on a given expression
def construct_filter_set(self, expression):
vocab = self.tokenizer.vocab
@@ -164,24 +169,26 @@ def construct_filter_set(self, expression):
# Preprocess expression to handle special cases
if expression[0] in specials:
expression = f'\{expression}'
elif len(expression) == 1 and (expression == '[' or expression == '('):
expression = f'\{expression}'

try:
# Compile the regex pattern and use it to match valid tokens in the vocabulary
pattern = re.compile(expression, re.UNICODE)
for token, id in vocab.items():
if pattern.match(token) is not None:
valid_tokens.append((token, id))
except Exception as e:
print("Regex Compiling Error: ", f"{e} - {expression}")

# Return a set of valid tokens based on the pattern
return set(valid_tokens)

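`construct_filter_set` escapes a handful of metacharacters by hand before compiling each terminal's pattern against the whole vocabulary. The same idea sketched with `re.escape`, which covers every metacharacter for literal terminals (the toy dict stands in for `tokenizer.vocab`):

```python
import re

vocab = {"def": 0, "(": 1, "(x": 2, "return": 3, "[": 4}  # toy stand-in

def filter_set(expression: str) -> set:
    # expression is a terminal's regex; match it against each vocab token.
    pattern = re.compile(expression, re.UNICODE)
    return {(tok, tid) for tok, tid in vocab.items() if pattern.match(tok)}

print(filter_set(re.escape("(")))  # {('(', 1), ('(x', 2)} -- order may vary
```
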
def parse_grammar(self):
# Create the Python parser with Lark, using the LALR algorithm
-self.parser = Lark.open_from_package('lark', 'python.lark', ['grammars'], parser='lalr', regex=True, lexer='contextual', postlex=PythonIndenter(), start='file_input')
+self.parser = Lark.open_from_package(
+    'lark', 'python.lark', ['grammars'], parser='lalr', regex=True, lexer='contextual', postlex=PythonIndenter(), start='file_input'
+)
terminals = self.parser.terminals
t_map = {}
for t in terminals:
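
`parse_grammar` derives a terminal-name-to-regex map from Lark's bundled Python grammar. A standalone sketch (requires the lark package; `regex=True` additionally needs the third-party regex module, so it is omitted here):

```python
from lark import Lark
from lark.indenter import PythonIndenter

parser = Lark.open_from_package(
    'lark', 'python.lark', ['grammars'],
    parser='lalr', lexer='contextual',
    postlex=PythonIndenter(), start='file_input',
)
t_map = {t.name: t.pattern.value for t in parser.terminals}
print(t_map['LPAR'])  # '(' -- these patterns feed construct_filter_set
```
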
@@ -193,12 +200,12 @@ def parse_grammar(self):
def _prefix_state(self, prefix_str=None, last_token=None):
valid_next = []
if self._parser_state is None:
try:
# Parse the entire token sequence
interactive_tree = self.parser.parse_interactive(prefix_str)
interactive_tree.exhaust_lexer()
self._parser_state = interactive_tree.copy()

# Get the valid next states
valid_next = list(interactive_tree.accepts())
except Exception as e:
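
`_prefix_state` relies on Lark's interactive parser: parse the prefix once, exhaust the lexer, then ask `accepts()` which terminal names may come next; incomplete prefixes can still raise, which is why the original wraps this in try/except. A minimal sketch under the same grammar assumptions as above:

```python
from lark import Lark
from lark.indenter import PythonIndenter

parser = Lark.open_from_package(
    'lark', 'python.lark', ['grammars'],
    parser='lalr', lexer='contextual',
    postlex=PythonIndenter(), start='file_input',
)

interactive = parser.parse_interactive("def f(")
interactive.exhaust_lexer()
print(interactive.accepts())  # terminal names that can legally follow the prefix
```
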
@@ -220,7 +227,7 @@ def _prefix_state(self, prefix_str=None, last_token=None):

def construct_final_filter_set(self, prefix_ids, terminals_map):
valid_next_ids = []

if self._copy_state == True:
# Decode only the last prefix ID
last_token = self.tokenizer.batch_decode(prefix_ids[:, -1], skip_special_tokens=True)[0]
@@ -239,13 +246,12 @@ def construct_final_filter_set(self, prefix_ids, terminals_map):
for t in token_set:
# Add valid token IDs to the list
valid_next_ids.append(t[-1])

# Return a set of valid next token IDs
return set(valid_next_ids)


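`construct_final_filter_set` then unions the cached token-id sets of every accepted terminal. Illustrative glue with made-up data:

```python
# terminals_map caches construct_filter_set results per terminal name;
# valid_terminals plays the role of the parser's accepts() output.
terminals_map = {"LPAR": {("(", 7), ("((", 8)}, "NAME": {("x", 11), ("foo", 12)}}
valid_terminals = ["LPAR", "NAME"]

valid_next_ids = {tid for name in valid_terminals
                  for _tok, tid in terminals_map.get(name, set())}
print(valid_next_ids)  # {7, 8, 11, 12} -- order may vary
```
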
-class JSONConstraint(GrammarConstraint):
-
+class JSONConstraint(GrammarConstraint):
# Construct a filter set for valid tokens based on a given expression
def construct_filter_set(self, expression):
vocab = self.tokenizer.vocab
@@ -255,21 +261,21 @@ def construct_filter_set(self, expression):
# Preprocess expression to handle special cases
if expression[0] in specials:
expression = f'\{expression}'
elif len(expression) == 1 and (expression == '[' or expression == '('):
expression = f'\{expression}'

try:
# Compile the regex pattern and use it to match all valid tokens in the vocabulary
pattern = re.compile(expression, re.UNICODE)
for token, id in vocab.items():
if pattern.match(token) is not None:
valid_tokens.append((token, id))
except Exception as e:
print(e, expression)

# Return a set of valid tokens based on the pattern
return set(valid_tokens)

def parse_grammar(self):
# Define JSON grammar
json_grammar = r"""
@@ -295,8 +301,8 @@ def parse_grammar(self):

%ignore WS
"""

# Create Lark internal transformer to make parsing faster and more memory efficient
class TreeToJson(Transformer):
@v_args(inline=True)
def string(self, s):
@@ -311,28 +317,25 @@ def string(self, s):
true = lambda self, _: True
false = lambda self, _: False


# Create the JSON parser with Lark, using the LALR algorithm
-self.parser = Lark(json_grammar, parser='lalr',
-                   lexer='contextual',
-                   transformer=TreeToJson())
+self.parser = Lark(json_grammar, parser='lalr', lexer='contextual', transformer=TreeToJson())
terminals = self.parser.terminals
t_map = {}
for t in terminals:
t_map[t.name] = t.pattern.value

# Return a map of terminal tokens and their corresponding regex patterns
return t_map

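The inline grammar above follows Lark's documented JSON example. A trimmed, runnable check that the LALR/contextual combination parses as expected (grammar abbreviated and the transformer left out):

```python
from lark import Lark

json_grammar = r"""
    ?start: value
    ?value: dict | list | ESCAPED_STRING | SIGNED_NUMBER
          | "true" | "false" | "null"
    dict : "{" [pair ("," pair)*] "}"
    pair : ESCAPED_STRING ":" value
    list : "[" [value ("," value)*] "]"

    %import common.ESCAPED_STRING
    %import common.SIGNED_NUMBER
    %import common.WS
    %ignore WS
"""

parser = Lark(json_grammar, parser='lalr', lexer='contextual')
print(parser.parse('{"a": [1, -2.5], "b": "x"}').pretty())
```
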
def _prefix_state(self, prefix_str=None, last_token=None):
valid_next = []
if self._parser_state is None:
try:
# Parse the entire token sequence
interactive_tree = self.parser.parse_interactive(prefix_str)
interactive_tree.exhaust_lexer()
self._parser_state = interactive_tree.copy()

# Get the valid next states
valid_next = list(interactive_tree.accepts())
except Exception as e:
@@ -354,7 +357,7 @@ def _prefix_state(self, prefix_str=None, last_token=None):

def construct_final_filter_set(self, prefix_ids, terminals_map):
valid_next_ids = []

if self._copy_state == True:
# Decode only the last prefix ID
last_token = self.tokenizer.batch_decode(prefix_ids[:, -1], skip_special_tokens=True)[0]
@@ -373,33 +376,32 @@ def construct_final_filter_set(self, prefix_ids, terminals_map):
for t in token_set:
# Add valid token IDs to the list
valid_next_ids.append(t[-1])

# Return a set of valid next token IDs
return set(valid_next_ids)


class GrammarLogitsProcessor(LogitsProcessor):

def __init__(self, constraint_class, terminals_map):
# Initialize the GrammarLogitsProcessor class with a constraint class and terminals map.
self.constraint_class = constraint_class
self.terminals_map = terminals_map

def __call__(self, input_ids, scores):
# This method is called for each generation step

# Initialize a boolean bias tensor with the same shape as scores
bias = torch.zeros_like(scores, dtype=torch.bool)

# Get the set of valid next token IDs for current step
valid_next_ids = self.constraint_class.construct_final_filter_set(input_ids, self.terminals_map)

# Set the bias to True for valid next token IDs
for id in valid_next_ids:
bias[0, id] = True

# Add the bias to the scores tensor to zero-out invalid next tokens
scores += bias

# Return the modified scores tensor
return scores
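
One observation on the bias trick: `scores += bias` adds a boolean tensor, which PyTorch promotes to 0/1 floats, so valid tokens get a +1 nudge rather than invalid ones being zeroed out. A common alternative (a sketch, not what this PR implements) forces invalid scores to -inf so they can never be sampled:

```python
import torch

def mask_invalid(scores: torch.Tensor, valid_next_ids: set) -> torch.Tensor:
    # Start fully masked, then re-open the valid ids.
    mask = torch.full_like(scores, float("-inf"))
    if valid_next_ids:
        idx = torch.tensor(sorted(valid_next_ids), dtype=torch.long)
        mask[0, idx] = 0.0
    return scores + mask

print(mask_invalid(torch.zeros(1, 10), {2, 5})[0])  # only ids 2 and 5 stay finite
```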