diff --git a/guidance/_grammar.py b/guidance/_grammar.py index 7500b54a3..16387ae4b 100644 --- a/guidance/_grammar.py +++ b/guidance/_grammar.py @@ -1,9 +1,8 @@ import re import types -from typing import Any, Dict, List, TYPE_CHECKING, TypeVar, Union +from typing import Any, TYPE_CHECKING, TypeVar, Union, cast, Optional -from . import _serialization_pb2 from . import _parser _T = TypeVar("_T") @@ -11,11 +10,12 @@ # to support the embedding of guidance functions inside Python f-strings we use tags with these delimiters tag_start = "{{G|" # start of a call tag tag_end = "|G}}" # end of a call tag -_call_pool: Dict[str, "Function"] = {} # the functions associated with the call tags +_call_pool: dict[str, "Function"] = {} # the functions associated with the call tags _tag_pattern = re.compile( re.escape(tag_start) + r"([^\|]+)" + re.escape(tag_end) ) # the pattern for matching call tags + class StatefulException(Exception): """This is raised when we try and use the state of a grammar object like it was a live model. @@ -49,13 +49,6 @@ def __str__(self): # return a string representation of this call so it can be combined with other strings/calls return tag_start + str_id + tag_end - def serialize(self): - raise NotImplementedError() - - @classmethod - def deserialize(cls, serialized_grammar): - raise NotImplementedError() - class RawFunction(Function): __slots__ = ("f", "args", "kwargs") @@ -179,21 +172,27 @@ def match( ) -> Union[Match, None]: if isinstance(byte_string, str): byte_string = byte_string.encode() - parser = _parser.EarleyCommitParser(self) - - for i in range(len(byte_string)): - try: - parser.consume_byte(byte_string[i : i + 1]) - except _parser.ParserException: - if raise_exceptions: - raise - else: - return None + parser = _parser.ByteParser(self) + + try: + parser.consume_bytes(byte_string) + except _parser.ByteParserException: + if raise_exceptions: + raise + else: + return None if not allow_partial and not parser.matched(): return None - else: - return Match(*parser.get_captures(), partial=not parser.matched()) # type: ignore[misc] + + if parser.matched(): + parser.force_done() + + return Match(*parser.get_captures(), partial=not parser.matched()) # type: ignore[misc] + + def forced_prefix(self) -> str: + parser = _parser.ByteParser(self) + return parser.bytes.decode("utf-8", errors="ignore") @staticmethod def _new_name(): @@ -222,72 +221,8 @@ def gbnf_string(self): lines.append("root ::= " + root_name) return "\n".join(lines) - def serialize(self): - g = _serialization_pb2.Grammar() - index_map = {} - nodes = {} - self._rec_create_index_map(index_map) # gives all the nodes an index - self._rec_serialize(index_map, nodes) # nodes is filled in (as is index_map) - g.nodes.extend(list(nodes.values())) - return g.SerializeToString() - - def _rec_create_index_map(self, index_map): - if self not in index_map: - index_map[self] = len(index_map) - if hasattr(self, "values"): - for value in self.values: - value._rec_create_index_map(index_map) - - def _rec_serialize(self, index_map, nodes): - if self not in nodes: - v = self._to_proto(index_map) - node = _serialization_pb2.GrammarFunction() - if isinstance(self, Byte): - node.byte.CopyFrom(v) - elif isinstance(self, ByteRange): - node.byte_range.CopyFrom(v) - elif isinstance(self, Select): - node.select.CopyFrom(v) - elif isinstance(self, Join): - node.join.CopyFrom(v) - elif isinstance(self, ModelVariable): - node.model_variable.CopyFrom(v) - else: - raise Exception("Unknown node type") - nodes[self] = node - if hasattr(self, "values"): - for value in self.values: - value._rec_serialize(index_map, nodes) - - @classmethod - def deserialize(cls, serialized_grammar): - g = _serialization_pb2.Grammar() - g.ParseFromString(serialized_grammar) - - # create the list of objects - values = [] - for node in g.nodes: - if node.HasField("byte"): - node = Byte._from_proto(node.byte) - elif node.HasField("byte_range"): - node = ByteRange._from_proto(node.byte_range) - elif node.HasField("select"): - node = Select._from_proto(node.select) - elif node.HasField("join"): - node = Join._from_proto(node.join) - elif node.HasField("model_variable"): - node = ModelVariable._from_proto(node.model_variable) - else: - raise Exception("Unknown node type") - values.append(node) - - # fill in the values pointers now that we have the full list of objects - for v in values: - if hasattr(v, "values"): - for i, index in enumerate(v.values): - v.values[i] = values[index] - - return values[0] # the first element in the root node of the grammar + def ll_serialize(self): + return {"grammars": LLSerializer().run(self)} class Terminal(GrammarFunction): @@ -300,14 +235,12 @@ def max_tokens(self): class Byte(Terminal): - __slots__ = ("byte", "hidden", "commit_point", "capture_name", "temperature") + __slots__ = ("byte", "capture_name", "temperature") def __init__(self, byte): assert isinstance(byte, bytes) assert len(byte) == 1 self.byte = byte - self.hidden = False - self.commit_point = False self.capture_name = None self.temperature = -1 @@ -330,38 +263,14 @@ def __len__(self): def match_byte(self, byte): return byte == self.byte - @property - def nullable(self): - return False - - def _to_proto(self, index_map): - data = _serialization_pb2.Byte() - data.byte = self.byte - data.hidden = self.hidden - data.commit_point = self.commit_point - data.capture_name = "" if self.capture_name is None else self.capture_name - data.temperature = self.temperature - return data - - @staticmethod - def _from_proto(data): - out = Byte(data.byte) - out.hidden = data.hidden - out.commit_point = data.commit_point - out.capture_name = None if data.capture_name == "" else data.capture_name - out.temperature = data.temperature - return out - class ByteRange(Terminal): - __slots__ = ("byte_range", "hidden", "commit_point", "capture_name", "temperature") + __slots__ = ("byte_range", "capture_name", "temperature") def __init__(self, byte_range): assert isinstance(byte_range, bytes) assert len(byte_range) == 2 self.byte_range = byte_range - self.hidden = False - self.commit_point = False self.capture_name = None self.temperature = -1 # -1 means not set @@ -376,9 +285,6 @@ def name(self): def name(self, value): pass # we ignore name changes - @property - def nullable(self): - return False def __hash__(self): return self.byte_range[0] + 256 * self.byte_range[1] @@ -396,34 +302,12 @@ def __repr__(self) -> str: def __len__(self): return 1 - def _to_proto(self, index_map): - data = _serialization_pb2.ByteRange() - data.byte_range = self.byte_range - data.hidden = self.hidden - data.commit_point = self.commit_point - data.capture_name = "" if self.capture_name is None else self.capture_name - data.temperature = self.temperature - return data - - @staticmethod - def _from_proto(data): - out = ByteRange(data.byte_range) - out.hidden = data.hidden - out.commit_point = data.commit_point - out.capture_name = None if data.capture_name == "" else data.capture_name - out.temperature = data.temperature - return out - class Null(Terminal): - __slots__ = ("name", "hidden", "commit_point", "capture_name") - - nullable = True + __slots__ = ("name", "capture_name") def __init__(self): - self.name = None - self.hidden = False - self.commit_point = False + self.name = "ε" self.capture_name = None def __add__(self, other): @@ -450,30 +334,11 @@ class ModelVariable(GrammarFunction): will get replaced with. """ - __slots__ = ("name", "hidden", "commit_point", "capture_name") + __slots__ = ("name", "capture_name") def __init__(self, name): self.name = name - self.hidden = False - self.commit_point = False self.capture_name = None - self.nullable = False - - def _to_proto(self, index_map): - data = _serialization_pb2.ModelVariable() - data.hidden = self.hidden - data.name = self.name - data.commit_point = self.commit_point - data.capture_name = "" if self.capture_name is None else self.capture_name - return data - - @staticmethod - def _from_proto(data): - out = ModelVariable(data.name) - out.hidden = data.hidden - out.commit_point = data.commit_point - out.capture_name = None if data.capture_name == "" else data.capture_name - return out def replace_grammar_node(grammar, target, replacement): @@ -558,10 +423,6 @@ def replace_model_variables(grammar, model, allowed_vars=None): obj = None if obj is not None: replacement_value = _wrap_as_grammar(getattr(obj, value.name)) - if value.commit_point: - replacement_value = commit_point( - replacement_value, hidden=value.hidden - ) replacements.append( (current, i, value) ) # Record the replacement @@ -630,41 +491,27 @@ def commit_point(value, hidden=False): Not that commit point nodes can be optionally hidden (in fact they are the only nodes that can be hidden since they are by definition not impacted by multiple possible inconsistent parses.)""" + raise NotImplementedError("commit_point is not implemented (may remove in the future)") + + +def as_regular_grammar(value): # TODO: assert that value is not empty since we don't yet support that if isinstance(value, str): value = string(value) - if isinstance(value, Terminal): - value = Join( - [value] - ) # commit points should be full nodes (otherwise we can't hide them) TODO: decide if we want to do this even for non-hidden commit points - value.commit_point = True - if hidden: - _rec_hide(value) - return value - - -def _rec_hide(grammar): - if not grammar.hidden: - grammar.hidden = True - if hasattr(grammar, "values"): - for g in grammar.values: - _rec_hide(g) + # check if it serializes + _ignore = LLSerializer().regex(value) + return RegularGrammar(value) class Placeholder(GrammarFunction): - __slots__ = tuple("nullable") - def __init__(self): - self.nullable = False + pass class Join(GrammarFunction): __slots__ = ( - "nullable", "values", "name", - "hidden", - "commit_point", "capture_name", "max_tokens", ) @@ -675,11 +522,8 @@ def __init__( values = [ string(v) if isinstance(v, (str, bytes)) else v for v in values ] # wrap raw strings - self.nullable = all(getattr(v, "nullable", False) for v in values) self.values = [v for v in values if not isinstance(v, Null)] self.name = name if name is not None else GrammarFunction._new_name() - self.hidden = False - self.commit_point = False self.capture_name = None self.max_tokens = max_tokens @@ -689,8 +533,6 @@ def __repr__(self, indent="", done=None): s = self.name.ljust(20) + " <- " + " ".join([v.name for v in self.values]) s += ( " " - + ("hidden " if self.hidden else "") - + ("commit_point " if self.commit_point else "") + (f"capture_name={self.capture_name} " if self.capture_name else "") + (f"max_tokens={self.max_tokens}" if self.max_tokens < 100000 else "") + "\n" @@ -701,39 +543,131 @@ def __repr__(self, indent="", done=None): s += v.__repr__(indent, done) return s - def _to_proto(self, index_map): - data = _serialization_pb2.Join() - data.nullable = self.nullable - for v in self.values: - data.values.append(index_map[v]) - data.name = self.name - data.hidden = self.hidden - data.commit_point = self.commit_point - data.capture_name = "" if self.capture_name is None else self.capture_name - data.max_tokens = self.max_tokens - return data - @staticmethod - def _from_proto(data): - out = Join( - data.values, # we put ints in that will be replaced later by the deserialize method - name=data.name, - max_tokens=data.max_tokens, +def quote_regex(value: str) -> str: + assert isinstance(value, str) + return re.sub(r"([\\+*?^$(){}\[\]\.|])", r"\\\1", value) + + +class Gen(Terminal): + __slots__ = ( + "body_regex", + "stop_regex", + "save_stop_text", + "name", + "capture_name", + "_max_tokens", + ) + + def __init__( + self, + body_regex: str, + stop_regex: str, + name: Union[str, None] = None, + save_stop_text: Optional[str] = None, + max_tokens=100000000, + ) -> None: + self.body_regex = body_regex + self.stop_regex = stop_regex + self.name = name if name is not None else GrammarFunction._new_name() + self.capture_name = None + self.save_stop_text = save_stop_text + self._max_tokens = max_tokens + self.temperature = -1 + + @property + def max_tokens(self): + return self._max_tokens + + def __repr__(self, indent="", done=None, lbl="Gen"): + if done is None: + done = set() + s = ( + self.name.ljust(20) + + " <- " + + lbl + + " " + + repr(self.body_regex) + + " + " + + repr(self.stop_regex) + ) + s += ( + " " + + (f"capture_name={self.capture_name} " if self.capture_name else "") + + (f"max_tokens={self.max_tokens}" if self.max_tokens < 100000 else "") + + "\n" ) - out.nullable = data.nullable - out.hidden = data.hidden - out.commit_point = data.commit_point - out.capture_name = None if data.capture_name == "" else data.capture_name - return out + done.add(self) + return s + + +class Lexeme(Gen): + __slots__ = ("contextual",) + + def __init__( + self, + body_regex: str, + contextual: bool = False, + name: Union[str, None] = None, + max_tokens=100000000, + ) -> None: + super().__init__(body_regex, "", name=name, max_tokens=max_tokens) + self.contextual = contextual + + def __repr__(self, indent="", done=None): + return super().__repr__(indent, done, "Lex") + + +class RegularGrammar(Gen): + __slots__ = ("grammar",) + + def __init__( + self, + grammar: GrammarFunction, + name: Union[str, None] = None, + max_tokens=100000000, + ) -> None: + super().__init__("", "", name=name, max_tokens=max_tokens) + self.grammar = grammar + + def __repr__(self, indent="", done=None): + # TODO add grammar repr + return super().__repr__(indent, done, "RegularGrammar") + + +class Subgrammar(Gen): + __slots__ = ( + "body", + "skip_regex", + "no_initial_skip", + ) + + def __init__( + self, + body: GrammarFunction, + skip_regex: Optional[str] = None, + no_initial_skip: bool = False, + name: Union[str, None] = None, + max_tokens=100000000, + ) -> None: + super().__init__( + body_regex="", + stop_regex="", + name=name, + max_tokens=max_tokens, + ) + self.body = body + self.skip_regex = skip_regex + self.no_initial_skip = no_initial_skip + + def __repr__(self) -> str: # type: ignore[override] + return self.name.ljust(20) + " <- " + self.body.name class Select(GrammarFunction): __slots__ = ( - "nullable", "_values", "name", - "hidden", - "commit_point", "capture_name", "max_tokens", "recursive", @@ -744,8 +678,6 @@ def __init__( ) -> None: self.values = values self.name = name if name is not None else GrammarFunction._new_name() - self.hidden = False - self.commit_point = False self.capture_name = capture_name self.max_tokens = max_tokens self.recursive = recursive @@ -757,8 +689,6 @@ def values(self): @values.setter def values(self, vals): self._values = [string(v) if isinstance(v, (str, bytes)) else v for v in vals] - self.nullable = any(getattr(v, "nullable", False) for v in self._values) - self._values = [v for v in self._values if not isinstance(v, Null)] def __repr__(self, indent="", done=None): if done is None: @@ -766,8 +696,6 @@ def __repr__(self, indent="", done=None): s = self.name.ljust(20) + " <- " + " | ".join([v.name for v in self.values]) s += ( " " - + ("hidden " if self.hidden else "") - + ("commit_point " if self.commit_point else "") + (f"max_tokens={self.max_tokens}" if self.max_tokens < 100000 else "") + "\n" ) @@ -777,36 +705,8 @@ def __repr__(self, indent="", done=None): s += v.__repr__(indent, done) return s - def _to_proto(self, index_map): - data = _serialization_pb2.Select() - data.nullable = self.nullable - for v in self.values: - data.values.append(index_map[v]) - data.name = self.name - data.hidden = self.hidden - data.commit_point = self.commit_point - data.capture_name = "" if self.capture_name is None else self.capture_name - data.max_tokens = self.max_tokens - data.recursive = self.recursive - return data - - @staticmethod - def _from_proto(data): - out = Select( - data.values, # we put ints in that will be replaced later by the deserialize method - name=data.name, - max_tokens=data.max_tokens, - ) - out.nullable = data.nullable - out.hidden = data.hidden - out.commit_point = data.commit_point - out.capture_name = None if data.capture_name == "" else data.capture_name - out.recursive = data.recursive - return out - - -def string(value: Union[str, bytes]) -> Union[Null, Byte, Join]: +def string(value: Union[str, bytes]) -> Union[Null, Join]: if isinstance(value, str): b = bytes(value, encoding="utf8") elif isinstance(value, bytes): @@ -815,14 +715,12 @@ def string(value: Union[str, bytes]) -> Union[Null, Byte, Join]: raise Exception("Must pass bytes or str to the string() function!") if len(value) == 0: return Null() - elif len(b) == 1: - return Byte(b) else: return Join([Byte(b[i : i + 1]) for i in range(len(b))], name=str(b)) def select( - options: List[_T], name=None, list_append=False, recurse=False, skip_checks=False + options: list[_T], name=None, list_append=False, recurse=False, skip_checks=False ) -> Union[Select, _T]: """Choose between a set of options. @@ -841,7 +739,7 @@ def select( If this is not None then the the results of the generation will be saved as a variable on the Model object (so you can access the result as `lm["var_name"]`). - options : List + options : list The set of available choices for the next generation list_append : bool @@ -952,7 +850,7 @@ def _re_with_temperature(grammar, temperature, visited_set): # if getattr(grammar, "temperature", 100000000) > temperature: if ( - isinstance(grammar, Terminal) and grammar.temperature < 0 + isinstance(grammar, Terminal) and not isinstance(grammar, Null) and grammar.temperature < 0 ): # only need to set temp for terminals grammar.temperature = temperature elif getattr(grammar, "temperature", 100000000) > temperature and hasattr( @@ -1019,3 +917,302 @@ def str_to_grammar(value: str): partial_grammar += string(part) is_id = not is_id return partial_grammar + + +def _is_string_literal(node: GrammarFunction): + if isinstance(node, Byte): + return True + if isinstance(node, Join): + return all(_is_string_literal(v) for v in node.values) + return False + + +class LLSerializer: + def __init__(self) -> None: + self.nodes: list[dict] = [] + self.curr_grammar = { + "nodes": self.nodes, + "rx_nodes": [], + } + self.grammars = [self.curr_grammar] + self.node_id_cache: dict[GrammarFunction, int] = {} + self.todo: list[GrammarFunction] = [] + self.grammar_id_cache: dict[Subgrammar, int] = {} + self.grammar_todo: list[Subgrammar] = [] + + self.regex_id_cache: dict[GrammarFunction, int] = {} + + def _add_regex_json(self, json): + id = len(self.curr_grammar["rx_nodes"]) + self.curr_grammar["rx_nodes"].append(json) + return id + + def _add_regex(self, key: str, val): + return self._add_regex_json({key: val}) + + def _regex_or(self, nodes: list[GrammarFunction]): + if len(nodes) == 1: + return self.regex_id_cache[nodes[0]] + else: + return self._add_regex("Or", [self.regex_id_cache[v] for v in nodes]) + + def regex(self, node: GrammarFunction): + """ + Serialize node as regex. Throws if impossible. + """ + + node0 = node + todo = [node] + pending: set[GrammarFunction] = set() + + def node_finished(node: GrammarFunction): + return node not in pending and node in self.regex_id_cache + + def all_finished(nodes): + return all(node_finished(v) for v in nodes) + + def add_todo(n: GrammarFunction): + if n in pending: + raise ValueError( + "GrammarFunction is recursive - cannot serialize as regex: " + + n.__repr__() + ) + todo.append(n) + + def add_todos(nodes): + for n in nodes: + add_todo(n) + + def check_unserializable_attrs(node: GrammarFunction): + if not isinstance(node, Terminal): + for v in getattr(node, "values", []): + # Only check one level deeper as we'll soon be processing the children + if isinstance(v, Terminal): + check_unserializable_attrs(v) + + if getattr(node, "capture_name", None) is not None: + raise ValueError( + f"Regex serialization does not support captures. Node: {node.__repr__()}" + ) + if getattr(node, "temperature", -1) >= 0: + raise ValueError( + f"Regex serialization does not support temperature. Node: {node.__repr__()}" + ) + + while todo: + node = todo.pop() + check_unserializable_attrs(node) + + if node in self.regex_id_cache: + continue + if isinstance(node, Select) and node.values: + with_node = [] + without_node = [] + for v in node.values: + if ( + isinstance(v, Join) + and len(v.values) == 2 + and v.values[0] is node + ): + with_node.append(v.values[1]) + else: + without_node.append(v) + if not all_finished(with_node) or not all_finished(without_node): + add_todo(node) + pending.add(node) + add_todos(with_node) + add_todos(without_node) + continue + #print(with_node, without_node) + if len(with_node) == 0: + # non-recursive + res = self._regex_or(without_node) + elif len(without_node) == 1 and isinstance(without_node[0], Null): + # zero_or_more() + inner = self._regex_or(with_node) + res = self._add_regex("Repeat", [inner, 0, None]) + elif with_node == without_node: + # one_or_more() + inner = self._regex_or(with_node) + res = self._add_regex("Repeat", [inner, 1, None]) + else: + raise ValueError( + "Cannot detect structure of recursive Select as regex: " + + node.__repr__() + ) + elif isinstance(node, Join): + if all(isinstance(v, Byte) for v in node.values): + literal = [cast(Byte, v).byte[0] for v in node.values] + try: + literal_ = bytes(literal).decode("utf-8", errors="strict") + res = self._add_regex("Literal", literal_) + except UnicodeDecodeError: + res = self._add_regex("ByteLiteral", literal) + else: + if not all_finished(node.values): + add_todo(node) + pending.add(node) + add_todos(node.values) + continue + res = self._add_regex( + "Concat", [self.regex_id_cache[v] for v in node.values] + ) + elif isinstance(node, Byte): + res = self._add_regex("Byte", node.byte[0]) + elif isinstance(node, ByteRange): + byteset = [0, 0, 0, 0, 0, 0, 0, 0] + for idx in range(256): + if node.match_byte(bytes([idx])): + byteset[idx // 32] |= 1 << (idx % 32) + res = self._add_regex("ByteSet", byteset) + elif isinstance(node, Null): + res = self._add_regex_json("EmptyString") + elif isinstance(node, Lexeme): + res = self._add_regex("Regex", node.body_regex) + else: + raise ValueError("Cannot serialize as regex: " + node.__repr__()) + if node in pending: + pending.remove(node) + self.regex_id_cache[node] = res + + assert not pending + return self.regex_id_cache[node0] + + def grammar(self, grammar: Subgrammar): + if grammar in self.grammar_id_cache: + return self.grammar_id_cache[grammar] + id = len(self.grammars) + self.grammar_id_cache[grammar] = id + self.grammars.append( + { + "greedy_skip_rx": grammar.skip_regex, + "nodes": [], + "rx_nodes": [], + } + ) + self.grammar_todo.append(grammar) + return id + + def node(self, node: GrammarFunction): + if node in self.node_id_cache: + return self.node_id_cache[node] + id = len(self.nodes) + self.node_id_cache[node] = id + self.nodes.append({}) + self.todo.append(node) + return id + + def process(self, node: GrammarFunction): + obj: dict[str, Any] = {} + if isinstance(node, Select): + obj = { + "Select": { + "among": [self.node(v) for v in node.values], + } + } + elif isinstance(node, Join): + if all(isinstance(v, Byte) for v in node.values): + literal = b"".join(cast(Byte, v).byte for v in node.values) + obj = { + "String": { + "literal": literal.decode("utf-8", errors="strict"), + } + } + else: + obj = { + "Join": { + "sequence": [self.node(v) for v in node.values], + } + } + elif isinstance(node, Lexeme): + obj = { + "Lexeme": { + "rx": node.body_regex, + "contextual": node.contextual, + } + } + elif isinstance(node, Subgrammar): + obj = { + "GenGrammar": { + "grammar": self.grammar(node), + "stop_rx": node.stop_regex, + "no_initial_skip": node.no_initial_skip, + "temperature": node.temperature if node.temperature >= 0 else None, + } + } + elif isinstance(node, RegularGrammar): + obj = { + "Gen": { + "body_rx": self.regex(node.grammar), + "stop_rx": "", + "lazy": False, # TODO this should be True + "temperature": node.temperature if node.temperature >= 0 else None, + } + } + elif isinstance(node, Gen): + obj = { + "Gen": { + "body_rx": node.body_regex, + "stop_rx": node.stop_regex, + "lazy": node.stop_regex != "", + "stop_capture_name": node.save_stop_text, + "temperature": node.temperature if node.temperature >= 0 else None, + } + } + elif isinstance(node, ByteRange): + # TODO: maybe raise a warning in this case, as user should probably be using a larger + # GenCommitPoint? + obj = { + "Gen": { + "body_rx": self.regex(node), + "stop_rx": "", + "lazy": True, + "temperature": node.temperature if node.temperature >= 0 else None, + } + } + elif isinstance(node, Byte): + obj = { + "String": { + "literal": node.byte.decode("utf-8", errors="strict"), + } + } + elif isinstance(node, Null): + obj = { + "String": { + "literal": "", + } + } + else: + raise Exception("Unknown node type:", type(node)) + tp = next(iter(obj)) + inner: dict = obj[tp] + if (capture_name:=getattr(node, "capture_name")): + inner["capture_name"] = capture_name + # Names on nodes are mostly useless + # if getattr(node, "name", None): + # inner["name"] = node.name + if (max_tokens:=getattr(node, "max_tokens")) and max_tokens < 1000000: + inner["max_tokens"] = max_tokens + self.nodes[self.node(node)] = obj + + def run_grammar(self, node: GrammarFunction): + assert self.todo == [] + id = self.node(node) + assert id == 0 + while self.todo: + node = self.todo.pop() + self.process(node) + + def run(self, node: GrammarFunction): + # avoid top-level node being a String + if _is_string_literal(node): + node = Select([node]) + self.run_grammar(node) + while self.grammar_todo: + grammar = self.grammar_todo.pop() + self.curr_grammar = self.grammars[self.grammar(grammar)] + self.nodes = cast(list[dict], self.curr_grammar["nodes"]) + self.node_id_cache = {} + self.regex_id_cache = {} + self.run_grammar(grammar.body) + return self.grammars diff --git a/guidance/_parser.py b/guidance/_parser.py index 600d0cb43..e6a22671d 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -1,666 +1,291 @@ -from sys import stderr +import json +import os +from typing import Any, Generator, Optional, Tuple, Union + +import llguidance # type: ignore[import-untyped] import numpy as np -from ordered_set import OrderedSet -from ._grammar import Join, Select, Terminal, Null, Byte, ByteRange +from numpy.typing import NDArray +from ._grammar import GrammarFunction, Join, Terminal +from ._schema import GenData, EngineCallResponse, LLInterpreterResponse +from .models._byte_tokenizer import ByteTokenizer +from .models._tokenizer import Tokenizer -class ParserException(Exception): - def __init__(self, *args, **kwargs): - self.current_byte = kwargs.pop("current_byte", None) - self.allowed_bytes = kwargs.pop("allowed_bytes", None) - self.consumed_bytes = kwargs.pop("consumed_bytes", None) - super().__init__(*args, **kwargs) + +class TokenParserException(Exception): + pass -class EarleyItem: - __slots__ = ( - "node", - "values", - "start", - "pos", - "log_prob", - "children", - "hidden_start", - ) - - def __init__(self, node, values, pos, start, log_prob, hidden_start): - self.node = node - self.values = values - self.start = start - self.pos = pos - self.log_prob = log_prob - self.children = None - self.hidden_start = hidden_start - - def __eq__(self, other): - return ( - isinstance(other, EarleyItem) - and self.start == other.start - and self.pos == other.pos - and self.node == other.node - and self.values == other.values - and self.log_prob == other.log_prob +class InvalidTokenException(TokenParserException): + def __init__(self, token: int, valid_tokens: list[int], prompt_tokens: list[int]): + self.token = token + self.valid_tokens = valid_tokens + self.prompt_tokens = prompt_tokens + super().__init__( + f"Invalid token {token}, expected one of {valid_tokens} after {prompt_tokens}" ) - def __hash__(self): - return hash((self.node, self.values, self.start, self.pos)) - - def __repr__(self): - if isinstance(self.node, Join): - s = f"{self.node.name:20} -> " - rs = "" - for i, v in enumerate(self.values): - if self.pos == i: - rs += "•" - rs += v.name + " " - if self.pos == len(self.values): - rs += "•" - elif isinstance(self.node, Select): - s = f"{self.node.name:20} -> " - rs = "" - if self.pos == 0: - rs += "•" - rs += self.values[0].name - if self.pos == 1: - rs += "•" - else: - assert False - return s + f"{rs:40} ({self.start}) {'nullable' if self.node.nullable else ''}" - -class Parser: - """An abstract base class for guidance parsers.""" +class TokenParser: + + def __init__( + self, + grammar: Union[GrammarFunction, str], + tokenizer: Tokenizer, + prompt: bytes = b"", + ensure_bos_token: bool = True, + ): + if isinstance(grammar, GrammarFunction): + # we can't have a terminal as the root + if isinstance(grammar, Terminal): + grammar = Join([grammar]) + serialized_grammar = json.dumps(grammar.ll_serialize()) + else: + serialized_grammar = grammar - pass + self.tokenizer = tokenizer + self.ll_tokenizer = llguidance.LLTokenizer( + llguidance.TokenizerWrapper(tokenizer) + ) + self.ll_interpreter = llguidance.LLInterpreter( + self.ll_tokenizer, + serialized_grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + self._generator = self._parse(prompt, ensure_bos_token) + self._done = False + + def is_accepting(self) -> bool: + return self.ll_interpreter.is_accepting() + + def done(self) -> bool: + return self._done + + def advance( + self, token: Optional[int] + ) -> Tuple[Optional[GenData], EngineCallResponse]: + try: + return self._generator.send(token) + except StopIteration as e: + self._done = True + return None, e.value + + def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]: + prompt_tokens = self.ll_interpreter.process_prompt( + self.tokenizer.encode(prompt) + ) + if ( + ensure_bos_token + and self.tokenizer.bos_token is not None + and prompt_tokens[:1] != [self.tokenizer.bos_token_id] + ): + # add the beginning of sequence token if needed + prompt_tokens = [self.tokenizer.bos_token_id] + prompt_tokens + return self.tokenizer.recode(prompt_tokens) -class EarleyCommitParser(Parser): - def __init__(self, grammar): - # we can't have a terminal as the root - if isinstance(grammar, Terminal): - grammar = Join([grammar]) + def _parse( + self, + prompt: bytes, + ensure_bos_token: bool, + ) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]: + tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token) - self.grammar = grammar - self.bytes = b"" - self.state_sets = [OrderedSet()] # the list of Earley items for each byte - self.token_counts = [] # used to track how many tokens have been used - self.state_set_pos = 0 - self.shadow_pos = 0 - self._add_node(self.grammar, 0, 0.0, 1000000000) - self._inner_loop(self.state_set_pos) - - @property - def pos(self): - return self.shadow_pos - - @pos.setter - def pos(self, new_pos): - - # do nothing if we aren't moving - if new_pos == self.state_set_pos: - return - elif new_pos > self.state_set_pos: - raise ParserException( - "Can't move the parser position forward! (only backward)" - ) - - # check if we are just moving the shadow position - if new_pos >= self.shadow_pos: - self.shadow_pos = new_pos - return + while True: + mask, resp = self.ll_interpreter.mid_process() + r = LLInterpreterResponse.model_validate_json(resp) + response = r.progress.to_engine_call_response() + if r.stop: + break - # actually reset our position if we need to - self.state_sets = self.state_sets[: new_pos + 1] + [OrderedSet()] - self.token_counts = self.token_counts[: new_pos + 2] - self.bytes = self.bytes[:new_pos] - self.state_set_pos = new_pos - self.shadow_pos = new_pos - self._inner_loop(self.state_set_pos) - - def _add_item(self, state_set_pos, new_item): - state_set = self.state_sets[state_set_pos] - if new_item not in state_set: - state_set.append(new_item) - else: - existing_item = state_set.items[state_set.map[new_item]] - existing_item.hidden_start = min( - existing_item.hidden_start, new_item.hidden_start - ) - - def _add_node(self, grammar, state_set_pos, log_prob, hidden_start): - if isinstance(grammar, Terminal): - new_item = EarleyItem( - grammar, tuple(), 0, state_set_pos, log_prob, hidden_start - ) - self._add_item(state_set_pos, new_item) - - elif isinstance(grammar, Join): - new_item = EarleyItem( - grammar, tuple(grammar.values), 0, state_set_pos, log_prob, hidden_start - ) - self._add_item(state_set_pos, new_item) - - elif isinstance(grammar, Select): - for value in grammar.values: - new_item = EarleyItem( - grammar, (value,), 0, state_set_pos, log_prob, hidden_start + if mask is not None: + assert r.temperature is not None + gen_data = GenData( + tokens=tokens, + mask=mask, + temperature=r.temperature, ) - self._add_item(state_set_pos, new_item) - - def _inner_loop(self, state_set_pos, start_pos=0): - curr_state_set = self.state_sets[state_set_pos] - if len(self.state_sets) == state_set_pos + 1: - self.state_sets.append(OrderedSet()) - self.token_counts.append( - self.token_counts[-1] if len(self.token_counts) > 0 else 0 - ) - next_state_set = self.state_sets[state_set_pos + 1] - pos = start_pos - while len(curr_state_set) > pos: - item = curr_state_set[pos] - - # completion - if item.pos == len(item.values): - - # if we complete an item that is a "commit point" then we eliminate all other possible - # parses so that we are "committed" to using this item - # we do this by removing any unprocessed items in the current state set and clearing the next state set - if item.node.commit_point: - while len(curr_state_set) > pos: - - # if we find another valid commit point that starts earlier we use that instead - # this causes us to pick the longest matching valid commit point - end_item = curr_state_set[-1] - if ( - end_item.node.commit_point - and end_item.pos == len(end_item.values) - and end_item.start < item.start - ): - item = end_item - - curr_state_set.pop() - curr_state_set.append( - item - ) # we append the current item again (we do this since we may have swapped it out above) - next_state_set.clear() - - # advance all the parents that our completion impacts - token_span = ( - self.token_counts[state_set_pos] - self.token_counts[item.start] - ) - start_state_set = self.state_sets[item.start] - for start_item in start_state_set: - if ( - start_item.pos < len(start_item.values) - and start_item.values[start_item.pos] == item.node - ): + # Send caller the mask and response; wait for token + token = yield (gen_data, response) + if token is None: + raise TokenParserException("Expected token, got None") + if not mask[token]: + # Note: we could punt this probem to ll_interpreter.post_process, + # but it's a bit clearer to handle it here + raise InvalidTokenException(token, gen_data.valid_next_tokens, tokens) + else: + gen_data = None + token = yield (gen_data, response) + if token is not None: + raise TokenParserException(f"Expected None, got token {token}") - # if item.node.max_tokens <= token_span and any(start_item.node == v and len(v.values) > 1 for v in item.node.values): - # continue # skip advancing parents that are also children (recursion) once we are past the token limit - - curr_state_set.append( - EarleyItem( - start_item.node, - start_item.values, - start_item.pos + 1, - start_item.start, - start_item.log_prob - + item.log_prob, # increment the log prob by the child value, - start_item.hidden_start, - ) - ) + backtrack, ff_tokens = self.ll_interpreter.post_process(token) + if backtrack: + tokens = tokens[:-backtrack] + tokens = tokens + ff_tokens - # don't advance past our max token limit - elif ( - item.node.max_tokens - > self.token_counts[state_set_pos] - self.token_counts[item.start] - ): - - # scan (note we only scan forward when we have more max token headroom left) - next_item_node = item.values[item.pos] - hidden_start = item.hidden_start - if next_item_node.hidden: - hidden_start = min(state_set_pos, hidden_start) - if isinstance( - next_item_node, Terminal - ): # and item.node.max_tokens > self.token_counts[state_set_pos] - self.token_counts[item.start]: - next_state_set.append( - EarleyItem( - item.node, - item.values, - item.pos + 1, - item.start, - item.log_prob, - hidden_start, - ) - ) # the log prob will get incremented when consume_bytes is called - - # prediction - else: - self._add_node( - next_item_node, state_set_pos, 0.0, hidden_start - ) # the log probs will get incremented by children later - - # handle nullable items by advancing them automatically (since we know we can) - if next_item_node.nullable: - new_item = EarleyItem( - item.node, - item.values, - item.pos + 1, - item.start, - item.log_prob, - item.hidden_start, - ) - if new_item not in self.state_sets[state_set_pos]: - self.state_sets[state_set_pos].append(new_item) - pos += 1 - - def earliest_hidden_start(self, state_pos=None): - """The earliest that a hidden node might match. - - This is useful because it tells us which bytes may end being hidden. - """ - if state_pos is None: - state_pos = self.state_set_pos - earliest_pos = 10000000000 - for item in self.state_sets[state_pos]: - earliest_pos = min(earliest_pos, item.hidden_start) - return earliest_pos - - def matched(self): - """Checks if the parser has completely matched the grammar.""" - if self.shadow_pos != self.state_set_pos: - return False - for item in self.state_sets[self.state_set_pos]: - if item.node == self.grammar and item.pos == len(item.values): - return True - return False + stop_reason = self.ll_interpreter.stop_reason() + if stop_reason not in {"NoExtension", "EndOfSentence"}: + # TODO: extend exception handling + raise TokenParserException(f"Unexpected stop reason: {stop_reason}") - def shadow_rewind(self, new_pos): - if new_pos == self.state_set_pos: - return - self.shadow_pos = new_pos - - def commit_and_collapse_item(self, item): - """This collapses the item into zero width and rewinds the parser position accordingly. - - Note we assume the item is in the current state set. - """ - - # trim off the state sets that matches this item - self.state_sets = self.state_sets[: item.start + 1] - self.token_counts = self.token_counts[: item.start + 1] - self.bytes = self.bytes[: item.start] - self.state_set_pos = item.start - self.shadow_pos = item.start - - # add this state to its start point (making it a zero length match with no values) - self.state_sets[item.start].append( - EarleyItem( - item.node, tuple(), 0, item.start, item.log_prob, item.hidden_start - ) - ) + return response - # expand from this state - self._inner_loop(item.start, len(self.state_sets[item.start]) - 1) - def mark_new_token(self): - # TODO: we allow ourselves to go one past our max token limit when we hit a one-byte token - # because we don't know if we are continuing or extending a new token when we parse - # the first byte of the token. We could fix this by rerunning the inner_loop after each - # token, but we skip that for now since max_tokens is not a hard garuntee anyway when you - # have patterns. +class ByteParserException(Exception): + def __init__(self, *args, **kwargs): + self.current_byte = kwargs.pop("current_byte", None) + self.allowed_bytes = kwargs.pop("allowed_bytes", None) + self.consumed_bytes = kwargs.pop("consumed_bytes", None) + super().__init__(*args, **kwargs) - self.token_counts[-1] += 1 - def consume_byte(self, byte, log_prob=0.0): - """Advances the parser by the given byte.""" +class ByteParser: + def __init__( + self, + grammar: GrammarFunction, + prompt: bytes = b"", + ensure_bos_token: bool = True, + ): + self.tokenizer = ByteTokenizer() + self.token_parser = TokenParser(grammar, self.tokenizer, prompt, ensure_bos_token) + self.bytes = b"" + self.gen_data: Optional[GenData] = None + self.pos = 0 + self._variables: dict[str, Any] = {} + self._variables_log_probs: dict[str, Any] = {} + self.consume_bytes(prompt) + + def matched(self) -> bool: + if self.pos < len(self.bytes): + return False + return self.token_parser.is_accepting() + + def valid_next_bytes(self) -> set[bytes]: + if self.pos < len(self.bytes): + return {self.bytes[self.pos : self.pos + 1]} + if self.gen_data is None: + return set() + return { + bytes([t]) + for t in self.gen_data.valid_next_tokens + if t != self.tokenizer.eos_token_id + } + + def next_byte_mask(self) -> NDArray[np.uint8]: + mask = np.zeros(256, dtype=np.uint8) + for t in self.valid_next_bytes(): + mask[t[0]] = 1 + return mask - # see if we need to advance our shadow position... - if self.shadow_pos < self.state_set_pos: - assert ( - byte == self.bytes[self.shadow_pos : self.shadow_pos + 1] - ), "Attempted to consume a byte by advancing shadow_pos but the byte didn't match!" - self.shadow_pos += 1 - return + def consume_bytes(self, bts: bytes) -> None: + # Run underlying ll_parser and fast-forward all of our bytes + # until we have a "choice" (generation step) to make + while self.gen_data is None and not self.token_parser.done(): + self.gen_data, response = self.token_parser.advance(None) + self._update_capture(response) + self.bytes += response.new_bytes - # ...if not, we extend our bytes - self.bytes += byte - - # filter out all the extensions that don't match this byte - new_next_state_set = [] - found_valid = False - found_invalid = False - hidden_start = 10000000000 - for item in self.state_sets[self.state_set_pos + 1]: - token_span = self.token_counts[-1] - self.token_counts[item.start] - if item.node.max_tokens <= token_span: - found_invalid = True - continue - elif item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal): - last_inner_node = item.values[item.pos - 1] - if not last_inner_node.match_byte(byte): - found_invalid = True - continue - else: - found_valid = True - if last_inner_node.commit_point: - item.log_prob += log_prob - new_next_state_set = [item] - hidden_start = min(hidden_start, item.hidden_start) - found_invalid = True # we make everything else invalid, so that means we found something invalid - break - item.log_prob += log_prob # update the probability of the item by the probability of choosing this byte - new_next_state_set.append(item) - hidden_start = min(hidden_start, item.hidden_start) - if not found_valid: - raise ParserException( - "Attempted to consume a byte that the grammar does not accept!", - current_byte=byte, - allowed_bytes=self.valid_next_bytes(), - consumed_bytes=self.bytes, - ) - if found_invalid: # only update if we changed the set - self.state_sets[self.state_set_pos + 1] = OrderedSet(new_next_state_set) - - # advance the parser one position - self.state_set_pos += 1 - self.shadow_pos += 1 - self._inner_loop(self.state_set_pos) - - # look for a commit point node - commit_point = None - for item in self.state_sets[self.state_set_pos]: - if ( - item.node.commit_point - and item.pos == len(item.values) - or (item.pos > 0 and item.values[item.pos - 1].commit_point) - ): - commit_point = item - break # TODO: consider how we might need to prioritize multiple commit point nodes (an uncommon scenario I think) - # hidden_start, - return commit_point - - def valid_next_bytes(self): - """A list of Byte and ByteRange objects representing the next valid bytes.""" - valid_items = set() - next_state_set = self.state_sets[self.state_set_pos + 1] - for item in next_state_set: - token_span = self.token_counts[-1] - self.token_counts[item.start] - if item.node.max_tokens <= token_span: - continue - elif item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal): - v = item.values[item.pos - 1] - if v not in valid_items: - valid_items.add(v) - return valid_items - - def next_byte_temperature(self): - """The maximum temperature over all the next bytes, or -1 if no temperature is set.""" - max_temp = -1 - next_state_set = self.state_sets[self.state_set_pos + 1] - for item in next_state_set: - if item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal): - v = item.values[item.pos - 1] - max_temp = max(max_temp, v.temperature) - return max_temp - - def next_byte_mask(self): - """A mask version of the `valid_next_bytes` method.""" - - mask = np.zeros(256, dtype=bool) - - # if we are shadow rewound then we just force those bytes again - if self.shadow_pos < self.state_set_pos: - mask[self.bytes[self.shadow_pos]] = True - - # otherwise we compute the valid bytes from the grammar - else: - valid_items = self.valid_next_bytes() - for item in valid_items: - if isinstance(item, Byte): - mask[item.byte[0]] = True - elif isinstance(item, ByteRange): - mask[item.byte_range[0] : item.byte_range[1] + 1] = True - else: - raise ParserException( - "Unknown Terminal Type: " + str(type(item)), - ) - return mask + if not bts: + return - def __repr__(self, state_sets=None) -> str: - s = "" - if state_sets is None: - _state_sets = self.state_sets + b = bts[0] + # If the current position is less than the length of the bytes, then we are in fast_forward mode + # and we need to make sure that the byte we are consuming is the same as the byte at the current + # position + if self.pos < len(self.bytes): + if b != self.bytes[self.pos]: + next_byte = self.bytes[self.pos : self.pos + 1] + raise ByteParserException( + f"Expected byte {next_byte!r} (fast_forward), got {bytes([b])!r}", + current_byte=bytes([b]), + allowed_bytes={next_byte}, + consumed_bytes=self.bytes[: self.pos], + ) + # Byte was good, move to the next byte + self.pos += 1 + self.consume_bytes(bts[1:]) else: - _state_sets = state_sets - for i, states in enumerate(_state_sets): - s += f"\n=== {i} ===" - if self.state_set_pos == i: - s += " (state_set_pos)" - s += "\n" - for state in states: - if isinstance(state.node, Join): - s += f"{state.node.name:20} -> " - rs = "" - for i, v in enumerate(state.values): - if state.pos == i: - rs += "•" - rs += v.name + " " - if state.pos == len(state.values): - rs += "•" - elif isinstance(state.node, Select): - s += f"{state.node.name:20} -> " - rs = "" - if state.pos == 0: - rs += "•" - if len(state.values) == 0: - rs += "NO_VALUES!" - else: - rs += state.values[0].name - if state.pos == 1: - rs += "•" - else: - assert False - s += f"{rs:40} ({state.start}) {'nullable' if state.node.nullable else ''}\n" # type: ignore[attr-defined] - return s - - def _reversed_state_sets(self): - new_state_sets = [OrderedSet([]) for _ in range(len(self.state_sets))] - for i, states in enumerate(self.state_sets): - for state in states: - # if state.node.name == "__call___c": - # pass - new_state_sets[state.start].append( - EarleyItem( - state.node, - state.values, - state.pos, - i, - state.log_prob, - state.hidden_start, - ) + # If we are here, then we are either in generation mode or we are done. + if self.gen_data is None: + # TODO: may run into trouble here if we need to backtrack + assert self.token_parser.done() + assert not self.valid_next_bytes() + raise ByteParserException( + f"Expected end of input, got {bytes([b])!r}", + current_byte=bytes([b]), + allowed_bytes=set(), + consumed_bytes=self.bytes[: self.pos], ) - - return new_state_sets - - def parse_tree(self): - reversed_state_sets = self._reversed_state_sets() - root_item = None - - # find the matching root state - for item in reversed_state_sets[0]: - if ( - item.node == self.grammar - and item.start == len(self.bytes) - and item.pos == len(item.values) - ): # note that ".start" mean end because items are reversed - root_item = item - if root_item is None: - return None - self._compute_parse_tree(0, root_item, reversed_state_sets) - return root_item - - def get_captures(self, data=None, log_prob_data=None): - root_node = self.parse_tree() - if data is None: - data = {} - if log_prob_data is None: - log_prob_data = {} - if root_node is not None: - # parse complete, so we can get the captures - self._record_captures_from_root(root_node, data, log_prob_data) - return data, log_prob_data - # compute on partially parsed tree - self._record_captures_partial(data, log_prob_data) - return data, log_prob_data - - def _record_captures_partial(self, data, log_prob_data): - byte_data = self.bytes - - for item in self.state_sets[self.state_set_pos]: - cname = item.node.capture_name - if cname is None: - continue - captured_value = byte_data[item.start : self.earliest_hidden_start()] - if captured_value.endswith(b"<"): - print( - "WARNING: Captured value ends with '<' which is a special character in the parser!", - file=stderr, + # We're in generation mode. Assure that the byte is one of the valid next bytes + if b not in self.gen_data.valid_next_tokens: + valid_next_bytes = self.valid_next_bytes() + raise ByteParserException( + f"Expected one of the following bytes: {valid_next_bytes!r}, got {bytes([b])!r}", + current_byte=bytes([b]), + allowed_bytes=valid_next_bytes, + consumed_bytes=self.bytes[: self.pos], ) - data[cname] = captured_value - log_prob_data[cname] = item.log_prob - - def _record_captures_from_root(self, initial_item, data, log_prob_data): - byte_data = self.bytes - stack = [(initial_item, 0)] - used_names = ( - set() - ) # track which capture names have been used so self-recursive children don't overwrite their parents - - while stack: - item, byte_pos = stack.pop() - # terminal nodes - if isinstance(item, Terminal): - - # if we are at a capture group node then we save the matched terminal byte - if item.capture_name is not None: - data[item.capture_name] = item.byte - log_prob_data[item.capture_name] = 0 - - # internal nodes - else: - start_byte_pos = byte_pos - - # recurse for all our non-null children - for child in item.children: - if child is not None: - stack.append((child, byte_pos)) - # _record_captures(child, data, log_prob_data, byte_data, byte_pos) - if isinstance(child, Terminal): - byte_pos += len(child) - else: - byte_pos = ( - child.start - ) # note that "start" means "end" since this is a reversed state set - - # if we are at a capture group node then we save the matched bytes range - # note that we record this after calling our children so that we save the outermost version of self-recursive calls - cname = item.node.capture_name - if ( - cname is not None - and cname not in used_names - and not item.node.hidden - ): - - # see if we are doing a list append - if cname.startswith("__LIST_APPEND:"): - cname = cname[14:] # trim off the list append tag - if cname not in data or not isinstance(data[cname], list): - data[cname] = [] - log_prob_data[cname] = [] - data[cname].append(byte_data[start_byte_pos : item.start]) - log_prob_data[cname].append(item.log_prob) - - # or just a regular assignment - else: - data[cname] = byte_data[ - start_byte_pos : item.start - ] # note that "start" means "end" since this is a reversed state set - log_prob_data[cname] = item.log_prob - - used_names.add(cname) - - def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets): - stack = [(initial_pos, initial_item)] - - while stack: - pos, item = stack.pop() - - # compute the children for this item - assert self._compute_children(pos, item, reversed_state_sets) - - # recurse on the children - for child in item.children: - if child is None: - pass # this child was nullable and was chosen to be null (empty) - elif isinstance(child, Terminal): - pos += len(child) - else: - stack.append((pos, child)) - pos = ( - child.start - ) # note that ".start" mean end because items are reversed - - def _compute_children(self, state_set_pos, item, reversed_state_sets, values_pos=0): - - # ensure we have a children array - if item.children is None: - item.children = [None for _ in range(len(item.values))] - - # consume as many terminal children as possible - while True: + # Byte was good, have ll_parser consume it so we can advance further + self.gen_data, response = self.token_parser.advance(b) + self._update_capture(response) + self.bytes += response.new_bytes + + # Run consume_bytes to advance ll_parser and consume the next byte + self.consume_bytes(bts) + + def force_done(self): + if not self.matched(): + raise ByteParserException("Hit end of input before reaching a valid state") + if self.token_parser.done(): + return - # if we are at the end of the values then there no more children and we see if we consumed all the right bytes - if values_pos == len(item.values): - return ( - state_set_pos == item.start - ) # note that ".start" mean end because items are reversed + self.gen_data, response = self.token_parser.advance(self.tokenizer.eos_token_id) + self._update_capture(response) + self.bytes += response.new_bytes + if not self.token_parser.done() or not self.matched(): + raise ByteParserException("Hit end of input before reaching a valid state") + + def get_captures(self): + return self._variables, self._variables_log_probs + + def _update_capture(self, response: EngineCallResponse): + # Stolen from model. TODO: refactor to share code + for k in response.capture_groups: + v = response.capture_groups[k] + + # see if we are in a list_append mode + if isinstance(v, list): + for i, inner_v in enumerate(v): + # convert to a string if possible + # TODO: will need to not just always do this once we support images etc. + try: + inner_v = ( + inner_v.decode("utf8") + if isinstance(inner_v, bytes) + else inner_v + ) + except UnicodeDecodeError: + pass - # get the child we are trying to match (meaning we are looking for completed early items for this node) - value = item.values[values_pos] + if k not in self._variables or not isinstance( + self._variables[k], list + ): + self._variables[k] = [] + self._variables_log_probs[k] = [] + self._variables[k].append(inner_v) + self._variables_log_probs[k].append( + response.capture_group_log_probs[k][i] + ) - # if we have a terminal node we can jump forward that many bytes - if isinstance(value, Terminal): - item.children[values_pos] = value - values_pos += 1 - state_set_pos += len(value) + # ...or standard assignment mode else: - break - - # otherwise we need to try all possible next matching items in the current state set - # so we loop over every item in the current state set looking for a completed match - for inner_item in reversed_state_sets[state_set_pos]: - if inner_item.node == value and inner_item.pos == len(inner_item.values): - - # see if we can get a complete parse following this inner item - if self._compute_children( - inner_item.start, item, reversed_state_sets, values_pos + 1 - ): - item.children[values_pos] = inner_item - return True - - # if we didn't find a child set and this is nullable we can skip this child (since it may not exist if nulled) - # we skip it by adding a fake EarlyItem with zero length (this makes zero length named captures still work) - if value.nullable: - if self._compute_children( - state_set_pos, item, reversed_state_sets, values_pos + 1 - ): - # this child has zero length since it was nullable - item.children[values_pos] = EarleyItem( - value, tuple(), 0, state_set_pos, 0, state_set_pos - ) - return True - - return False + # convert to a string if possible + # TODO: will need to not just always do this once we support images etc. + try: + v = v.decode("utf8") if isinstance(v, bytes) else v + except UnicodeDecodeError: + pass + self._variables[k] = v + self._variables_log_probs[k] = response.capture_group_log_probs[k] diff --git a/guidance/_schema.py b/guidance/_schema.py new file mode 100644 index 000000000..e5c9d15e8 --- /dev/null +++ b/guidance/_schema.py @@ -0,0 +1,120 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, Field, NonNegativeInt, RootModel, model_validator, computed_field +from typing_extensions import Annotated +from functools import cached_property + + +class GuidanceEngineMetrics(BaseModel): + engine_input_tokens: NonNegativeInt = 0 + engine_output_tokens: NonNegativeInt = 0 + + +class EngineCallResponse(BaseModel): + new_bytes: bytes + is_generated: bool + new_bytes_prob: float + capture_groups: dict + capture_group_log_probs: dict + new_token_count: NonNegativeInt + + +class GenData(BaseModel): + tokens: list[int] + mask: bytes + temperature: float + + @computed_field # type: ignore[misc] + @cached_property + def valid_next_tokens(self) -> list[int]: + return [i for i, b in enumerate(self.mask) if b != 0] + + +class LLProgressCapture(BaseModel): + object: Literal["capture"] + name: str + hex: str + log_prob: float + list_append: bool = False + + @model_validator(mode="before") + def strip_list_append_prefix(cls, values): + name = values["name"] + if name.startswith("__LIST_APPEND:"): + values["name"] = name[14:] + # Override whatever was set + values["list_append"] = True + return values + + +class LLProgressText(BaseModel): + object: Literal["text"] + hex: str + num_tokens: NonNegativeInt + log_prob: float + is_generated: bool + + +class LLProgressFinalText(BaseModel): + object: Literal["final_text"] + # we don't need to handle this for now + + +LLProgressItem = Annotated[ + Union[LLProgressCapture, LLProgressText, LLProgressFinalText], + Field(discriminator="object"), +] + + +class LLProgress(RootModel): + root: list[LLProgressItem] + + def to_engine_call_response(self) -> EngineCallResponse: + new_bytes = b"" + new_token_count = 0 + new_bytes_prob = 0.0 + is_generated = False + capture_groups: dict[str, Any] = {} + capture_group_log_probs: dict[str, Any] = {} + num_text_entries = 0 + + for j in self.root: + if isinstance(j, LLProgressCapture): + is_generated = True + cname = j.name + data = bytes.fromhex(j.hex) + if j.list_append: + if cname not in capture_groups or not isinstance( + capture_groups[cname], list + ): + capture_groups[cname] = [] + capture_group_log_probs[cname] = [] + capture_groups[cname].append(data) + capture_group_log_probs[cname].append(j.log_prob) + else: + capture_groups[cname] = data + capture_group_log_probs[cname] = j.log_prob + elif isinstance(j, LLProgressText): + # it actually should only happen once per round... + new_bytes += bytes.fromhex(j.hex) + new_token_count += j.num_tokens + new_bytes_prob += j.log_prob + is_generated |= j.is_generated + num_text_entries += 1 + if num_text_entries > 0: + new_bytes_prob /= num_text_entries + + return EngineCallResponse( + new_bytes=new_bytes, + new_token_count=new_token_count, + new_bytes_prob=new_bytes_prob, + is_generated=is_generated, + capture_groups=capture_groups, + capture_group_log_probs=capture_group_log_probs, + ) + + +class LLInterpreterResponse(BaseModel): + progress: LLProgress + stop: bool + temperature: Optional[float] diff --git a/guidance/_serialization.proto b/guidance/_serialization.proto deleted file mode 100644 index c3e55e34c..000000000 --- a/guidance/_serialization.proto +++ /dev/null @@ -1,101 +0,0 @@ -syntax = "proto3"; - -package guidance; - -message Grammar { - repeated GrammarFunction nodes = 1; -} - -message EngineCallResponse { - bytes new_bytes = 1; - bool is_generated = 2; - float new_bytes_prob = 3; - map capture_groups = 4; - map capture_group_log_probs = 5; - int32 new_token_count = 6; -} - -message Value { - oneof kind { - string string_value = 1; - bytes bytes_value = 2; - float float_value = 3; - ListValue list_value = 4; - } -} - -message ListValue { - repeated Value values = 1; -} - -message Byte { - bytes byte = 1; - bool hidden = 2; - bool commit_point = 3; - bool nullable = 4; - string capture_name = 5; - float temperature = 6; -} - -message ByteRange { - bytes byte_range = 1; - bool hidden = 3; - bool commit_point = 4; - string capture_name = 5; - float temperature = 6; -} - -message Null { -} - -message ModelVariable { - string name = 1; - bool hidden = 2; - bool commit_point = 3; - string capture_name = 4; - bool nullable = 5; -} - -message Join { - bool nullable = 1; - - // Use a repeated field to store the list of values - repeated int32 values = 2; - - string name = 3; - bool hidden = 4; - bool commit_point = 5; - string capture_name = 6; - int32 max_tokens = 7; -} - -message Select { - bool nullable = 1; - - // Use a repeated field to store the list of values - repeated int32 values = 2; - - string name = 3; - bool hidden = 4; - bool commit_point = 5; - string capture_name = 6; - int32 max_tokens = 7; - bool recursive = 8; -} - -// message Terminal { -// oneof function_type { -// Byte byte = 1; -// ByteRange byte_range = 2; -// } -// } - -message GrammarFunction { - oneof function_type { - Join join = 1; - Select select = 2; - Byte byte = 3; - ByteRange byte_range = 4; - ModelVariable model_variable = 5; - } -} \ No newline at end of file diff --git a/guidance/_serialization_pb2.py b/guidance/_serialization_pb2.py deleted file mode 100644 index d54beb3bb..000000000 --- a/guidance/_serialization_pb2.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: _serialization.proto -# Protobuf Python Version: 4.25.3 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14_serialization.proto\x12\x08guidance\"3\n\x07Grammar\x12(\n\x05nodes\x18\x01 \x03(\x0b\x32\x19.guidance.GrammarFunction\"\xa5\x03\n\x12\x45ngineCallResponse\x12\x11\n\tnew_bytes\x18\x01 \x01(\x0c\x12\x14\n\x0cis_generated\x18\x02 \x01(\x08\x12\x16\n\x0enew_bytes_prob\x18\x03 \x01(\x02\x12G\n\x0e\x63\x61pture_groups\x18\x04 \x03(\x0b\x32/.guidance.EngineCallResponse.CaptureGroupsEntry\x12W\n\x17\x63\x61pture_group_log_probs\x18\x05 \x03(\x0b\x32\x36.guidance.EngineCallResponse.CaptureGroupLogProbsEntry\x12\x17\n\x0fnew_token_count\x18\x06 \x01(\x05\x1a\x45\n\x12\x43\x61ptureGroupsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.guidance.Value:\x02\x38\x01\x1aL\n\x19\x43\x61ptureGroupLogProbsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.guidance.Value:\x02\x38\x01\"\x80\x01\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x12)\n\nlist_value\x18\x04 \x01(\x0b\x32\x13.guidance.ListValueH\x00\x42\x06\n\x04kind\",\n\tListValue\x12\x1f\n\x06values\x18\x01 \x03(\x0b\x32\x0f.guidance.Value\"w\n\x04\x42yte\x12\x0c\n\x04\x62yte\x18\x01 \x01(\x0c\x12\x0e\n\x06hidden\x18\x02 \x01(\x08\x12\x14\n\x0c\x63ommit_point\x18\x03 \x01(\x08\x12\x10\n\x08nullable\x18\x04 \x01(\x08\x12\x14\n\x0c\x63\x61pture_name\x18\x05 \x01(\t\x12\x13\n\x0btemperature\x18\x06 \x01(\x02\"p\n\tByteRange\x12\x12\n\nbyte_range\x18\x01 \x01(\x0c\x12\x0e\n\x06hidden\x18\x03 \x01(\x08\x12\x14\n\x0c\x63ommit_point\x18\x04 \x01(\x08\x12\x14\n\x0c\x63\x61pture_name\x18\x05 \x01(\t\x12\x13\n\x0btemperature\x18\x06 \x01(\x02\"\x06\n\x04Null\"k\n\rModelVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06hidden\x18\x02 \x01(\x08\x12\x14\n\x0c\x63ommit_point\x18\x03 \x01(\x08\x12\x14\n\x0c\x63\x61pture_name\x18\x04 \x01(\t\x12\x10\n\x08nullable\x18\x05 \x01(\x08\"\x86\x01\n\x04Join\x12\x10\n\x08nullable\x18\x01 \x01(\x08\x12\x0e\n\x06values\x18\x02 \x03(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06hidden\x18\x04 \x01(\x08\x12\x14\n\x0c\x63ommit_point\x18\x05 \x01(\x08\x12\x14\n\x0c\x63\x61pture_name\x18\x06 \x01(\t\x12\x12\n\nmax_tokens\x18\x07 \x01(\x05\"\x9b\x01\n\x06Select\x12\x10\n\x08nullable\x18\x01 \x01(\x08\x12\x0e\n\x06values\x18\x02 \x03(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06hidden\x18\x04 \x01(\x08\x12\x14\n\x0c\x63ommit_point\x18\x05 \x01(\x08\x12\x14\n\x0c\x63\x61pture_name\x18\x06 \x01(\t\x12\x12\n\nmax_tokens\x18\x07 \x01(\x05\x12\x11\n\trecursive\x18\x08 \x01(\x08\"\xe4\x01\n\x0fGrammarFunction\x12\x1e\n\x04join\x18\x01 \x01(\x0b\x32\x0e.guidance.JoinH\x00\x12\"\n\x06select\x18\x02 \x01(\x0b\x32\x10.guidance.SelectH\x00\x12\x1e\n\x04\x62yte\x18\x03 \x01(\x0b\x32\x0e.guidance.ByteH\x00\x12)\n\nbyte_range\x18\x04 \x01(\x0b\x32\x13.guidance.ByteRangeH\x00\x12\x31\n\x0emodel_variable\x18\x05 \x01(\x0b\x32\x17.guidance.ModelVariableH\x00\x42\x0f\n\rfunction_typeb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, '_serialization_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPSENTRY']._options = None - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPSENTRY']._serialized_options = b'8\001' - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPLOGPROBSENTRY']._options = None - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPLOGPROBSENTRY']._serialized_options = b'8\001' - _globals['_GRAMMAR']._serialized_start=34 - _globals['_GRAMMAR']._serialized_end=85 - _globals['_ENGINECALLRESPONSE']._serialized_start=88 - _globals['_ENGINECALLRESPONSE']._serialized_end=509 - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPSENTRY']._serialized_start=362 - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPSENTRY']._serialized_end=431 - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPLOGPROBSENTRY']._serialized_start=433 - _globals['_ENGINECALLRESPONSE_CAPTUREGROUPLOGPROBSENTRY']._serialized_end=509 - _globals['_VALUE']._serialized_start=512 - _globals['_VALUE']._serialized_end=640 - _globals['_LISTVALUE']._serialized_start=642 - _globals['_LISTVALUE']._serialized_end=686 - _globals['_BYTE']._serialized_start=688 - _globals['_BYTE']._serialized_end=807 - _globals['_BYTERANGE']._serialized_start=809 - _globals['_BYTERANGE']._serialized_end=921 - _globals['_NULL']._serialized_start=923 - _globals['_NULL']._serialized_end=929 - _globals['_MODELVARIABLE']._serialized_start=931 - _globals['_MODELVARIABLE']._serialized_end=1038 - _globals['_JOIN']._serialized_start=1041 - _globals['_JOIN']._serialized_end=1175 - _globals['_SELECT']._serialized_start=1178 - _globals['_SELECT']._serialized_end=1333 - _globals['_GRAMMARFUNCTION']._serialized_start=1336 - _globals['_GRAMMARFUNCTION']._serialized_end=1564 -# @@protoc_insertion_point(module_scope) diff --git a/guidance/_serialization_pb2.pyi b/guidance/_serialization_pb2.pyi deleted file mode 100644 index 9c7300120..000000000 --- a/guidance/_serialization_pb2.pyi +++ /dev/null @@ -1,351 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" - -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Grammar(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - NODES_FIELD_NUMBER: builtins.int - @property - def nodes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GrammarFunction]: ... - def __init__( - self, - *, - nodes: collections.abc.Iterable[global___GrammarFunction] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["nodes", b"nodes"]) -> None: ... - -global___Grammar = Grammar - -@typing.final -class EngineCallResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class CaptureGroupsEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - @property - def value(self) -> global___Value: ... - def __init__( - self, - *, - key: builtins.str = ..., - value: global___Value | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - @typing.final - class CaptureGroupLogProbsEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - @property - def value(self) -> global___Value: ... - def __init__( - self, - *, - key: builtins.str = ..., - value: global___Value | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - NEW_BYTES_FIELD_NUMBER: builtins.int - IS_GENERATED_FIELD_NUMBER: builtins.int - NEW_BYTES_PROB_FIELD_NUMBER: builtins.int - CAPTURE_GROUPS_FIELD_NUMBER: builtins.int - CAPTURE_GROUP_LOG_PROBS_FIELD_NUMBER: builtins.int - NEW_TOKEN_COUNT_FIELD_NUMBER: builtins.int - new_bytes: builtins.bytes - is_generated: builtins.bool - new_bytes_prob: builtins.float - new_token_count: builtins.int - @property - def capture_groups(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___Value]: ... - @property - def capture_group_log_probs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___Value]: ... - def __init__( - self, - *, - new_bytes: builtins.bytes = ..., - is_generated: builtins.bool = ..., - new_bytes_prob: builtins.float = ..., - capture_groups: collections.abc.Mapping[builtins.str, global___Value] | None = ..., - capture_group_log_probs: collections.abc.Mapping[builtins.str, global___Value] | None = ..., - new_token_count: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["capture_group_log_probs", b"capture_group_log_probs", "capture_groups", b"capture_groups", "is_generated", b"is_generated", "new_bytes", b"new_bytes", "new_bytes_prob", b"new_bytes_prob", "new_token_count", b"new_token_count"]) -> None: ... - -global___EngineCallResponse = EngineCallResponse - -@typing.final -class Value(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - STRING_VALUE_FIELD_NUMBER: builtins.int - BYTES_VALUE_FIELD_NUMBER: builtins.int - FLOAT_VALUE_FIELD_NUMBER: builtins.int - LIST_VALUE_FIELD_NUMBER: builtins.int - string_value: builtins.str - bytes_value: builtins.bytes - float_value: builtins.float - @property - def list_value(self) -> global___ListValue: ... - def __init__( - self, - *, - string_value: builtins.str = ..., - bytes_value: builtins.bytes = ..., - float_value: builtins.float = ..., - list_value: global___ListValue | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["bytes_value", b"bytes_value", "float_value", b"float_value", "kind", b"kind", "list_value", b"list_value", "string_value", b"string_value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["bytes_value", b"bytes_value", "float_value", b"float_value", "kind", b"kind", "list_value", b"list_value", "string_value", b"string_value"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["kind", b"kind"]) -> typing.Literal["string_value", "bytes_value", "float_value", "list_value"] | None: ... - -global___Value = Value - -@typing.final -class ListValue(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - VALUES_FIELD_NUMBER: builtins.int - @property - def values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Value]: ... - def __init__( - self, - *, - values: collections.abc.Iterable[global___Value] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["values", b"values"]) -> None: ... - -global___ListValue = ListValue - -@typing.final -class Byte(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - BYTE_FIELD_NUMBER: builtins.int - HIDDEN_FIELD_NUMBER: builtins.int - COMMIT_POINT_FIELD_NUMBER: builtins.int - NULLABLE_FIELD_NUMBER: builtins.int - CAPTURE_NAME_FIELD_NUMBER: builtins.int - TEMPERATURE_FIELD_NUMBER: builtins.int - byte: builtins.bytes - hidden: builtins.bool - commit_point: builtins.bool - nullable: builtins.bool - capture_name: builtins.str - temperature: builtins.float - def __init__( - self, - *, - byte: builtins.bytes = ..., - hidden: builtins.bool = ..., - commit_point: builtins.bool = ..., - nullable: builtins.bool = ..., - capture_name: builtins.str = ..., - temperature: builtins.float = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["byte", b"byte", "capture_name", b"capture_name", "commit_point", b"commit_point", "hidden", b"hidden", "nullable", b"nullable", "temperature", b"temperature"]) -> None: ... - -global___Byte = Byte - -@typing.final -class ByteRange(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - BYTE_RANGE_FIELD_NUMBER: builtins.int - HIDDEN_FIELD_NUMBER: builtins.int - COMMIT_POINT_FIELD_NUMBER: builtins.int - CAPTURE_NAME_FIELD_NUMBER: builtins.int - TEMPERATURE_FIELD_NUMBER: builtins.int - byte_range: builtins.bytes - hidden: builtins.bool - commit_point: builtins.bool - capture_name: builtins.str - temperature: builtins.float - def __init__( - self, - *, - byte_range: builtins.bytes = ..., - hidden: builtins.bool = ..., - commit_point: builtins.bool = ..., - capture_name: builtins.str = ..., - temperature: builtins.float = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["byte_range", b"byte_range", "capture_name", b"capture_name", "commit_point", b"commit_point", "hidden", b"hidden", "temperature", b"temperature"]) -> None: ... - -global___ByteRange = ByteRange - -@typing.final -class Null(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - def __init__( - self, - ) -> None: ... - -global___Null = Null - -@typing.final -class ModelVariable(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - NAME_FIELD_NUMBER: builtins.int - HIDDEN_FIELD_NUMBER: builtins.int - COMMIT_POINT_FIELD_NUMBER: builtins.int - CAPTURE_NAME_FIELD_NUMBER: builtins.int - NULLABLE_FIELD_NUMBER: builtins.int - name: builtins.str - hidden: builtins.bool - commit_point: builtins.bool - capture_name: builtins.str - nullable: builtins.bool - def __init__( - self, - *, - name: builtins.str = ..., - hidden: builtins.bool = ..., - commit_point: builtins.bool = ..., - capture_name: builtins.str = ..., - nullable: builtins.bool = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["capture_name", b"capture_name", "commit_point", b"commit_point", "hidden", b"hidden", "name", b"name", "nullable", b"nullable"]) -> None: ... - -global___ModelVariable = ModelVariable - -@typing.final -class Join(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - NULLABLE_FIELD_NUMBER: builtins.int - VALUES_FIELD_NUMBER: builtins.int - NAME_FIELD_NUMBER: builtins.int - HIDDEN_FIELD_NUMBER: builtins.int - COMMIT_POINT_FIELD_NUMBER: builtins.int - CAPTURE_NAME_FIELD_NUMBER: builtins.int - MAX_TOKENS_FIELD_NUMBER: builtins.int - nullable: builtins.bool - name: builtins.str - hidden: builtins.bool - commit_point: builtins.bool - capture_name: builtins.str - max_tokens: builtins.int - @property - def values(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: - """Use a repeated field to store the list of values""" - - def __init__( - self, - *, - nullable: builtins.bool = ..., - values: collections.abc.Iterable[builtins.int] | None = ..., - name: builtins.str = ..., - hidden: builtins.bool = ..., - commit_point: builtins.bool = ..., - capture_name: builtins.str = ..., - max_tokens: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["capture_name", b"capture_name", "commit_point", b"commit_point", "hidden", b"hidden", "max_tokens", b"max_tokens", "name", b"name", "nullable", b"nullable", "values", b"values"]) -> None: ... - -global___Join = Join - -@typing.final -class Select(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - NULLABLE_FIELD_NUMBER: builtins.int - VALUES_FIELD_NUMBER: builtins.int - NAME_FIELD_NUMBER: builtins.int - HIDDEN_FIELD_NUMBER: builtins.int - COMMIT_POINT_FIELD_NUMBER: builtins.int - CAPTURE_NAME_FIELD_NUMBER: builtins.int - MAX_TOKENS_FIELD_NUMBER: builtins.int - RECURSIVE_FIELD_NUMBER: builtins.int - nullable: builtins.bool - name: builtins.str - hidden: builtins.bool - commit_point: builtins.bool - capture_name: builtins.str - max_tokens: builtins.int - recursive: builtins.bool - @property - def values(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: - """Use a repeated field to store the list of values""" - - def __init__( - self, - *, - nullable: builtins.bool = ..., - values: collections.abc.Iterable[builtins.int] | None = ..., - name: builtins.str = ..., - hidden: builtins.bool = ..., - commit_point: builtins.bool = ..., - capture_name: builtins.str = ..., - max_tokens: builtins.int = ..., - recursive: builtins.bool = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["capture_name", b"capture_name", "commit_point", b"commit_point", "hidden", b"hidden", "max_tokens", b"max_tokens", "name", b"name", "nullable", b"nullable", "recursive", b"recursive", "values", b"values"]) -> None: ... - -global___Select = Select - -@typing.final -class GrammarFunction(google.protobuf.message.Message): - """message Terminal { - oneof function_type { - Byte byte = 1; - ByteRange byte_range = 2; - } - } - """ - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - JOIN_FIELD_NUMBER: builtins.int - SELECT_FIELD_NUMBER: builtins.int - BYTE_FIELD_NUMBER: builtins.int - BYTE_RANGE_FIELD_NUMBER: builtins.int - MODEL_VARIABLE_FIELD_NUMBER: builtins.int - @property - def join(self) -> global___Join: ... - @property - def select(self) -> global___Select: ... - @property - def byte(self) -> global___Byte: ... - @property - def byte_range(self) -> global___ByteRange: ... - @property - def model_variable(self) -> global___ModelVariable: ... - def __init__( - self, - *, - join: global___Join | None = ..., - select: global___Select | None = ..., - byte: global___Byte | None = ..., - byte_range: global___ByteRange | None = ..., - model_variable: global___ModelVariable | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["byte", b"byte", "byte_range", b"byte_range", "function_type", b"function_type", "join", b"join", "model_variable", b"model_variable", "select", b"select"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["byte", b"byte", "byte_range", b"byte_range", "function_type", b"function_type", "join", b"join", "model_variable", b"model_variable", "select", b"select"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["function_type", b"function_type"]) -> typing.Literal["join", "select", "byte", "byte_range", "model_variable"] | None: ... - -global___GrammarFunction = GrammarFunction diff --git a/guidance/_server.py b/guidance/_server.py index 0d74cf090..e2e3848b3 100644 --- a/guidance/_server.py +++ b/guidance/_server.py @@ -1,7 +1,7 @@ import base64 import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator try: import pydantic @@ -14,7 +14,7 @@ raise from .models._model import Model, Engine -from ._grammar import GrammarFunction +from ._schema import EngineCallResponse class GuidanceRequest(pydantic.BaseModel): @@ -65,14 +65,16 @@ async def extend_parser( if x_api_key not in self.valid_api_keys: raise HTTPException(status_code=401, detail="Invalid API key") - # data = await request.json() - # parser = data.get("parser") - grammar = GrammarFunction.deserialize( - base64.b64decode(guidance_request.grammar) + engine_responses: Iterator[EngineCallResponse] = self.engine( + guidance_request.parser, guidance_request.grammar + ) + # Note the use of a generator comprehension here -- this will be evaluated lazily + json_stream: Iterator[str] = ( + response.model_dump_json() for response in engine_responses ) return StreamingResponse( - self.engine(guidance_request.parser, grammar), + json_stream, media_type="application/json", ) diff --git a/guidance/library/__init__.py b/guidance/library/__init__.py index d7c23fa6a..9b8a2b0cf 100644 --- a/guidance/library/__init__.py +++ b/guidance/library/__init__.py @@ -1,5 +1,5 @@ # import functions that can be called directly -from ._gen import gen, call_tool, will_gen +from ._gen import gen, call_tool, will_gen, regex from ._image import image from ._capture import capture @@ -28,9 +28,7 @@ from ._char_set import char_set from ._prefix_tree import prefix_tree from ._substring import substring -from ._regex import regex from ._optional import optional from ._tool import Tool from ._any_char_but import any_char_but - from ._json import json diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py index bf8118b5d..644bf3b61 100644 --- a/guidance/library/_gen.py +++ b/guidance/library/_gen.py @@ -1,23 +1,17 @@ import regex as regex_module import logging from .._guidance import guidance +from .._grammar import select, Gen, quote_regex, capture, token_limit, with_temperature +from ._block import block from ._silent import silent -from .._grammar import select -from ._sequences import zero_or_more -from .._grammar import commit_point -from ._any_char import any_char -from .._grammar import capture -from ._regex import regex as regex_grammar -from .._grammar import token_limit, eos_token, active_role_end, with_temperature from ._tool import Tool -from ._block import block logger = logging.getLogger(__name__) # TODO: make this stateless! # TODO: uncomment this once we get temperature stateless -@guidance(stateless=lambda *args, **kwargs: kwargs.get("tools", None) is None) +@guidance(stateless=lambda *args, **kwargs: kwargs.get("tools", None) is None) def gen( lm, name=None, @@ -98,6 +92,9 @@ def gen( call from the model's context if you plan to change it's format after the call is made. """ # TODO: expand the tools doc string + if [tools, regex].count(None) == 0: + raise ValueError("Cannot use regex with tools") + assert ( n == 1 ), "We still need to add support for n>1! Consider putting your gen call in a loop for now." @@ -105,100 +102,84 @@ def gen( logger.debug(f'start gen(name="{name}")') - # set stream if we are interactive - # if stream_tokens is None and not lm.is_silent() and n == 1: - # stream_tokens = True - - # use the suffix as the stop string if not otherwise specified - # TODO: still need to make suffix work with grammars - # eos_token = lm.eos_token.decode('utf8') if stop is None and stop_regex is None and suffix != "": stop = suffix - # if stop is None and stop_regex is None and getattr(lm, "suffix", False): - # if lm.suffix.startswith("\n"): - # stop = "\n" - # elif lm.suffix.startswith('"') and str(lm).endswith('"'): - # stop = '"' - # elif lm.suffix.startswith("'") and str(lm).endswith("'"): - # stop = "'" - - # fall back to stopping at the EOS token + + # Empty stop condition is implicitly the EOS token + gen_stop = "" if stop is not False: if stop is None: stop = [] if isinstance(stop, str): stop = [stop] - if regex is None: - stop = stop + [select([eos_token(), active_role_end()])] if stop_regex is None: stop_regex = [] if isinstance(stop_regex, str): stop_regex = [stop_regex] - stop_regex = [regex_grammar(x) for x in stop_regex] - # This needs to be here for streaming - # if name is not None and not list_append: - # lm[name] = "" + stop_regex += [quote_regex(s) for s in stop] + if len(stop_regex) == 1: + gen_stop = stop_regex[0] + else: + gen_stop = "|".join("(" + s + ")" for s in stop_regex) - # define the generation pattern - if regex is not None: - pattern = regex_grammar(regex) - else: - pattern = zero_or_more(any_char()) + if regex is None: + regex = r"(?s:.*)" + if save_stop_text is True: + save_stop_text = str(name) + "_stop_text" + if not isinstance(save_stop_text, str): + save_stop_text = None tagged_name = "__LIST_APPEND:" + name if list_append and name is not None else name - # define any capture group for non-tool calls - if name is not None and tools is None: - pattern = capture(pattern, name=tagged_name) - - # limit the number of tokens - pattern = token_limit(pattern, max_tokens) - - # define the stop pattern - if stop is False or len(stop + stop_regex) == 0: - stop_pattern = "" - else: - stop_pattern = select(stop + stop_regex) - if save_stop_text is True: - save_stop_text = str(name) + "_stop_text" - if isinstance(save_stop_text, str): - stop_pattern = capture(stop_pattern, name=save_stop_text) - stop_pattern = commit_point(stop_pattern, hidden=True) - - # single generation - start_pos = len(str(lm)) if tools is not None: - with block(tagged_name): - tools = [Tool(callable=x) if not isinstance(x, Tool) else x for x in tools] - init_token_count = lm.token_count - gen_grammar = pattern + select( - [stop_pattern] - + [ - capture( - commit_point(x.call_grammar, hidden=hide_tool_call), - name=f"tool{i}", - ) - for i, x in enumerate(tools) - ] + tools = [Tool(callable=x) if not isinstance(x, Tool) else x for x in tools] + options = [ + Gen(body_regex=regex, stop_regex=gen_stop, save_stop_text=save_stop_text, max_tokens=max_tokens) + ] + for i, tool in enumerate(tools): + # Infer a regex that will match the start of a tool call + tool_call_prefix = tool.call_grammar.forced_prefix() + if len(tool_call_prefix) < 4: + # TODO: alternatively check that the prefix contains the name (case insensitive) of the tool? + # anything shorter is probably far too ambiguous + raise ValueError(f"Could not infer unambiguous tool call prefix for tool {tool.name}") + options.append( + capture( + Gen(body_regex=regex, stop_regex=quote_regex(tool_call_prefix), max_tokens=max_tokens), + name=f"tool{i}" + ) ) - while lm.token_count <= max_tokens + init_token_count: - lm = lm._run_stateless( - gen_grammar, temperature=temperature - ) # TODO: we should not be using this internal method + grm = with_temperature(select(options), temperature) + initial_token_count = lm.token_count + with block(tagged_name): + while lm.token_count <= max_tokens + initial_token_count: + lm += grm tool_called = False for i in range(len(tools)): tool_i = f"tool{i}" if tool_i in lm: tool_called = True - lm += tools[i].tool_call() - lm = lm.remove(tool_i) + if hide_tool_call: + temp_lm = lm + tools[i].call_grammar + with block("tool_call"): + temp_lm += tools[i].tool_call() + lm += temp_lm["tool_call"] + else: + lm += tools[i].call_grammar + tools[i].tool_call() + lm = lm.remove(tool_i) if not tool_called: lm += suffix break - elif n == 1: - lm += with_temperature(pattern + stop_pattern + suffix, temperature) + return lm + + pattern = Gen(body_regex=regex, stop_regex=gen_stop, save_stop_text=save_stop_text, name=tagged_name, max_tokens=max_tokens) + + # define any capture group for non-tool calls + if name is not None and tools is None: + pattern = capture(pattern, name=tagged_name) + lm += with_temperature(pattern + suffix, temperature) logger.debug(f"finish gen") return lm @@ -283,4 +264,11 @@ def will_gen(lm, stop=None, stop_regex=None, ignore_spaces=False, max_tokens=30) @guidance def call_tool(lm, tool): - return lm + tool.call_grammar + tool.tool_call() + lm += tool.call_grammar + lm += tool.tool_call() + return lm + + +@guidance(stateless=True) +def regex(lm, pattern, *, name=None): + return lm + gen(regex=pattern, name=name) diff --git a/guidance/library/_json.py b/guidance/library/_json.py index 7d3f679f3..ffc4ce0e5 100644 --- a/guidance/library/_json.py +++ b/guidance/library/_json.py @@ -24,6 +24,7 @@ from .._grammar import GrammarFunction, select, capture, with_temperature from ._pydantic import pydantic_to_json_schema +from ._subgrammar import lexeme, subgrammar def _to_compact_json(target: Any) -> str: @@ -67,6 +68,7 @@ class Keyword(str, Enum): "object": {"properties", "additionalProperties"}, } +WHITESPACE = {b" ", b"\t", b"\n", b"\r"} STRING_CHARS = [ char_range("a", "z"), char_range("A", "Z"), @@ -92,17 +94,16 @@ def validate_json_node_keys(node: Mapping[str, Any]): @guidance(stateless=True) def _gen_json_int(lm): - pos_nonzero = char_range("1", "9") + sequence(char_range("0", "9")) - return lm + optional("-") + select(["0", pos_nonzero]) + return lm + lexeme(r"-?(?:0|[1-9][0-9]*)", contextual=True) @guidance(stateless=True) def _gen_json_number(lm): - mantissa_int = _gen_json_int() - mantissa_frac = "." + one_or_more(char_range("0", "9")) - exponent = "e" + select(["", "+", "-"]) + one_or_more(char_range("0", "9")) - - return lm + mantissa_int + optional(mantissa_frac) + optional(exponent) + return lm + select([ + _gen_json_int(), + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)", contextual=True), + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)", contextual=True), + ]) @guidance(stateless=True) @@ -112,7 +113,6 @@ def _gen_json_string( max_length: Union[int, None] = None, regex: Union[str, None] = None, ): - lm += '"' if regex is not None: if min_length > 0 or max_length is not None: msg = ( @@ -121,10 +121,12 @@ def _gen_json_string( "left unspecified." ) raise ValueError(msg) - lm += gen(regex=regex) - else: - lm += sequence(select(STRING_CHARS), min_length=min_length, max_length=max_length) - return lm + '"' + return '"' + gen(regex=regex) + '"' + + char_expr = r'(\\([\"\\\/bfnrt]|u[a-fA-F0-9]{4})|[^\"\\\x00-\x1F\x7F])' + range_expr = f"{{{min_length},{max_length}}}" if max_length is not None else f"{{{min_length},}}" + string_expr = f'"{char_expr}{range_expr}"' + return lm + lexeme(string_expr, contextual=True) @guidance(stateless=True) @@ -285,7 +287,7 @@ def _process_anyOf( @guidance(stateless=True) def _process_enum(lm, *, options: Sequence[Mapping[str, Any]]): - # options will come in as python objects, so we need to convert to (compact) JSON + # TODO: can we support a whitespace-flexible version of this? all_opts = [] for opt in options: all_opts.append(_to_compact_json(opt)) @@ -341,6 +343,7 @@ def _gen_json( return lm + _get_definition(reference=json_schema[Keyword.REF], definitions=definitions) if Keyword.CONST in json_schema: + # TODO: can we support a whitespace-flexible version of this? return lm + _to_compact_json(json_schema[Keyword.CONST]) if Keyword.ENUM in json_schema: @@ -392,7 +395,9 @@ def json( Type["pydantic.BaseModel"], "pydantic.TypeAdapter", ] = None, + compact: bool = False, temperature: float = 0.0, + max_tokens: int = 100000000, ): """Generate valid JSON according to the supplied JSON schema or `pydantic` model. @@ -435,6 +440,10 @@ def json( - A JSON schema object. This is a JSON schema string which has been passed to ``json.loads()`` - A subclass of ``pydantic.BaseModel`` - An instance of ``pydantic.TypeAdapter`` + + compact : bool + If True, the generated JSON will be forced to be compact (no whitespace). + If False, output will be whitespace-flexible (i.e. decided by the model). """ if isinstance(schema, Mapping): # Raises jsonschema.exceptions.SchemaError or ValueError @@ -451,9 +460,18 @@ def json( assert len(definitions) == 0, "Found duplicate definitions" definitions = _build_definitions(schema[dk]) - return lm + capture( - with_temperature(_gen_json(schema, definitions), temperature=temperature), - name=name, + return lm + with_temperature( + subgrammar( + name, + body=_gen_json(json_schema=schema, definitions=definitions), + skip_regex=( + None if compact + else r"[\x20\x0A\x0D\x09]+" + ), + no_initial_skip=True, + max_tokens=max_tokens, + ), + temperature=temperature, ) @@ -463,7 +481,7 @@ def _build_definitions( definitions: Dict[str, Callable[[], GrammarFunction]] = {} def build_definition(json_schema: Mapping[str, Any]) -> Callable[[], GrammarFunction]: - @guidance(stateless=True, dedent=False) + @guidance(stateless=True, dedent=False, cache=True) def closure(lm): return lm + _gen_json(json_schema=json_schema, definitions=definitions) diff --git a/guidance/library/_regex.py b/guidance/library/_regex.py deleted file mode 100644 index f20e4d584..000000000 --- a/guidance/library/_regex.py +++ /dev/null @@ -1,154 +0,0 @@ -import sys - -if sys.version_info >= (3, 11): - import re._constants as constants # type: ignore[import-not-found] - import re._parser as parser # type: ignore[import-not-found] -else: - import sre_parse as parser - import sre_constants as constants - -from re import RegexFlag -from typing import Any, List, Tuple, Union - -from typing_extensions import TypeAlias - -from .._grammar import Byte, ByteRange, Join, Select, byte_range, select, capture -from .._guidance import guidance -from ._any_char_but import any_char_but -from ._sequences import sequence - -# Type aliases -Node: TypeAlias = Tuple[constants._NamedIntConstant, Any] - - -class UnsupportedRegexError(Exception): - pass - - -class RegexPatternConverter: - - @classmethod - def parse(cls, pattern: str): - return cls.convert(parser.parse(pattern)) - - @classmethod - def convert(cls, tree: Union[parser.SubPattern, Node], flags: int = 0): - if flags != 0: - # Equivalent to re.NOFLAG - raise UnsupportedRegexError( - f"Flags other than re.NOFLAG not supported; got {RegexFlag(flags)}" - ) - if isinstance(tree, parser.SubPattern): - if len(tree.data) == 1: - return cls.convert(tree.data[0]) - return Join([cls.convert(node) for node in tree.data]) - - opcode, args = tree - opcode_name = opcode.name - try: - method = getattr(cls, opcode_name) - except AttributeError as e: - raise UnsupportedRegexError( - f"Unsupported regex feature with opcode {opcode_name}" - ) from e - return method(args) - - @classmethod - def SUBPATTERN(cls, args: Tuple[int, int, int, parser.SubPattern]): - # capture group - group, add_flags, del_flags, arg = args - flags = add_flags & ~del_flags - return cls.convert(arg, flags) - - @classmethod - def LITERAL(cls, args: int): - # byte - return Byte(bytes([args])) - - @classmethod - def NOT_LITERAL(cls, args: int): - return any_char_but(chr(args)) - - @classmethod - def RANGE(cls, args: Tuple[int, int]): - # byte_range - low, high = args - return byte_range(bytes([low]), bytes([high])) - - @classmethod - def ANY(cls, _: None): - return any_char_but("\n") - - @classmethod - def IN(cls, args: List[Node]): - if args[0][0] == constants.NEGATE: - transformed_args = [cls.convert(arg) for arg in args[1:]] - negated_bytes = cls._get_negated_bytes(transformed_args) - return any_char_but(negated_bytes) - transformed_args = [cls.convert(arg) for arg in args] - return select(transformed_args) - - @classmethod - def _get_negated_bytes(cls, grammars: List[Union[Byte, ByteRange, Select]]): - negated_bytes = set() - for value in grammars: - if isinstance(value, Byte): - negated_bytes.add(value.byte) - elif isinstance(value, ByteRange): - low, high = value.byte_range - negated_bytes.update([bytes([i]) for i in range(low, high + 1)]) - elif isinstance(value, Select): - negated_bytes.update(cls._get_negated_bytes(value._values)) - else: - raise TypeError(f"Can't negate {type(value)} object") - return negated_bytes - - @classmethod - def BRANCH(cls, args: Tuple[Any, List[parser.SubPattern]]): - unknown, arg = args - if unknown is not None: - # Unsure of the semantics of this value, but it seems to be - # None in all cases tested so far - raise UnsupportedRegexError(f"Unkwnown argument in BRANCH: {unknown}") - transformed_args = [cls.convert(a) for a in arg] - return select(transformed_args) - - @classmethod - def MAX_REPEAT( - cls, - args: Tuple[int, Union[int, constants._NamedIntConstant], parser.SubPattern], - ): - low, high, arg = args - transformed_arg = cls.convert(arg) - if isinstance(high, constants._NamedIntConstant): - if high != constants.MAXREPEAT: - raise UnsupportedRegexError(f"Unsupported high value in range: {high}") - return sequence(transformed_arg, min_length=low) - return sequence(transformed_arg, min_length=low, max_length=high) - - @classmethod - def CATEGORY(cls, args: constants._NamedIntConstant): - # \d - if args.name == "CATEGORY_DIGIT": - return cls.parse(r"[0-9]") - # \D - if args.name == "CATEGORY_NOT_DIGIT": - return cls.parse(r"[^0-9]") - # \w - if args.name == "CATEGORY_WORD": - return cls.parse(r"[0-9A-Za-z_]") - # \W - if args.name == "CATEGORY_NOT_WORD": - return cls.parse(r"[^0-9A-Za-z_]") - # \s - if args.name == "CATEGORY_SPACE": - return cls.parse(r"[ \t\n\r\f\v]") - # \S - if args.name == "CATEGORY_NOT_SPACE": - return cls.parse(r"[^ \t\n\r\f\v]") - raise UnsupportedRegexError(f"Unsupported category: {args.name}") - - -@guidance(stateless=True) -def regex(lm, pattern, *, name=None): - return lm + capture(RegexPatternConverter.parse(pattern), name=name) diff --git a/guidance/library/_subgrammar.py b/guidance/library/_subgrammar.py new file mode 100644 index 000000000..1c8045694 --- /dev/null +++ b/guidance/library/_subgrammar.py @@ -0,0 +1,28 @@ +from .._grammar import Subgrammar, Lexeme, GrammarFunction, capture +from typing import Optional + + +def lexeme( + body_regex: str, + contextual: bool = False, +): + return Lexeme(body_regex=body_regex, contextual=contextual) + + +def subgrammar( + name: str = None, + *, + body: GrammarFunction, + skip_regex: Optional[str] = None, + no_initial_skip: bool = False, + max_tokens=100000000, +): + r = Subgrammar( + body=body, + skip_regex=skip_regex, + no_initial_skip=no_initial_skip, + max_tokens=max_tokens, + ) + if name: + r = capture(r, name) + return r diff --git a/guidance/library/_substring.py b/guidance/library/_substring.py index 9a9cc8443..b284324ca 100644 --- a/guidance/library/_substring.py +++ b/guidance/library/_substring.py @@ -3,7 +3,7 @@ from .._guidance import guidance # from ._prefix_tree import prefix_tree -from .._grammar import string, select, capture +from .._grammar import string, select, capture, as_regular_grammar from ._optional import optional @@ -135,7 +135,7 @@ def substring(lm, target_string: str, name: Optional[str] = None): ) state_stack.pop() - return lm + capture(node_cache[0], name=name) + return lm + capture(as_regular_grammar(node_cache[0]), name=name) # @guidance(stateless=True, dedent=False) diff --git a/guidance/library/_tool.py b/guidance/library/_tool.py index 21a2e9140..792393d86 100644 --- a/guidance/library/_tool.py +++ b/guidance/library/_tool.py @@ -1,10 +1,8 @@ from .._guidance import guidance -from ._any_char import any_char -from .._grammar import select, capture, string, commit_point -from ._sequences import zero_or_more, one_or_more -from ._any_char_but import any_char_but -from ._any_char import any_char - +from .._grammar import select +from ._optional import optional +from ._sequences import zero_or_more +from ._subgrammar import lexeme, subgrammar class Tool: def __init__(self, call_grammar=None, tool_call=None, callable=None): @@ -25,25 +23,25 @@ def __init__(self, call_grammar=None, tool_call=None, callable=None): self.tool_call = tool_call -def valid_chars(): - return any_char_but(["=", ")"]) - - -def positional_arg(): - return one_or_more(valid_chars()) - - -def kwarg(): - return one_or_more(valid_chars()) + "=" + one_or_more(valid_chars()) - +arg = lexeme(r"[^,=)]+") +kwarg = arg + "=" + arg +args = arg + zero_or_more("," + arg) +kwargs = kwarg + zero_or_more("," + kwarg) def basic_func_grammar(name): - obj = string(name + "(") - obj += capture( - select([zero_or_more(positional_arg()), ""]) + select([zero_or_more(kwarg()), ""]), + obj = name + "(" + obj += subgrammar( name="tool_args", + body=optional( + select([ + args, + kwargs, + args + "," + kwargs, + ]) + ), + skip_regex=r" *" ) - obj += string(")") + obj += ")" return obj diff --git a/guidance/models/_azure_guidance.py b/guidance/models/_azure_guidance.py index 16eef04b8..353e55e56 100644 --- a/guidance/models/_azure_guidance.py +++ b/guidance/models/_azure_guidance.py @@ -1,36 +1,41 @@ import requests import os -import base64 import json import urllib.parse -from ._model import Engine, Model, EngineCallResponse +from ._model import Engine, Model +from .._schema import LLProgress from ..chat import Phi3MiniChatTemplate from ._byte_tokenizer import ByteTokenizer +from typing import Dict, Tuple, Optional class AzureGuidanceEngine(Engine): """This connects to a remote guidance server on Azure and runs all computation using the remote engine.""" - def __init__(self, server_url, max_streaming_tokens=1000, chat_template=None): - if ( - server_url is None - or isinstance(server_url, str) - and len(server_url.strip()) == 0 - ): + def __init__( + self, + server_url, + max_streaming_tokens=1000, + chat_template=None, + log_level=1, + ): + if server_url is None or isinstance(server_url, str) and len(server_url.strip()) == 0: server_url = os.getenv("AZURE_GUIDANCE_URL", "") elif not isinstance(server_url, str): raise ValueError("server_url must contain a URL string.") - if not server_url.startswith("https://"): - raise ValueError( - "AzureGuidance requires a remote model URL that starts with https://" - ) - self.server_url = server_url + if ( + not server_url.startswith("https://") + and not server_url.startswith("http://") + ): + raise ValueError("AzureGuidance requires a remote model URL that starts with https:// or http://") + self.conn_str = server_url self.max_streaming_tokens = max_streaming_tokens + self.log_level = log_level if chat_template is None: # TODO [PK]: obtain this from the server - chat_template=Phi3MiniChatTemplate + chat_template = Phi3MiniChatTemplate tokenizer = ByteTokenizer(chat_template) @@ -38,17 +43,25 @@ def __init__(self, server_url, max_streaming_tokens=1000, chat_template=None): super().__init__(tokenizer=tokenizer, compute_log_probs=False) def __call__(self, parser, grammar, ensure_bos_token=True): - b64 = base64.b64encode(grammar.serialize()).decode("utf-8") - + serialized = {"grammar": grammar.ll_serialize()} + # this is a hack to avoid loops + serialized["grammar"]["max_tokens"] = self.max_streaming_tokens + # print(json.dumps(serialized)) data = { - "controller": "guidance", - "controller_arg": {"guidance_b64": b64}, + "controller": "llguidance", + "controller_arg": serialized, "prompt": parser, "max_tokens": self.max_streaming_tokens, - "temperature": 0.0, # this is just default temperature + "temperature": 0.0, # this is just default temperature } - resp = req("post", "run", json=data, stream=True, base_url=self.server_url) + url, headers, info = _mk_url("run", conn_str=self.conn_str) + if self.log_level >= 4: + print(f"POST {info}", flush=True) + if self.log_level >= 5: + print(f" {json.dumps(data, indent=None)}", flush=True) + resp = requests.request("post", url, headers=headers, json=data, stream=True) + if resp.status_code != 200: text = resp.text try: @@ -58,7 +71,8 @@ def __call__(self, parser, grammar, ensure_bos_token=True): except: pass raise RuntimeError( - f"Bad response to Guidance request {resp.status_code} {resp.reason}: {text}." + f"Bad response to Guidance request\nRequest: {info}\n" + + f"Response: {resp.status_code} {resp.reason}\n{text}" ) for line in resp.iter_lines(): @@ -70,51 +84,39 @@ def __call__(self, parser, grammar, ensure_bos_token=True): if "forks" not in d: continue for ch in d["forks"]: - capture_groups = {} - capture_group_log_probs = {} - if "Previous WASM Error" in ch["logs"]: raise RuntimeError("Previous WASM Error.") idx = ch["index"] assert idx == 0, "unexpected index in response from server" - new_bytes = b"" - new_token_count = 0 - new_bytes_prob = 0.0 - num_text_entries = 0 + progress = [] for ln in ch["logs"].split("\n"): ln: str if ln.startswith("JSON-OUT: "): j = json.loads(ln[10:]) - tag = j.get("object", "") - if tag == "capture": - capture_groups[j["name"]] = bytes.fromhex(j["hex"]) - capture_group_log_probs[j["name"]] = j["log_prob"] - elif tag == "text": - # it actually should only happen once per round... - new_bytes += bytes.fromhex(j["hex"]) - new_token_count += j["num_tokens"] - new_bytes_prob += j["log_prob"] - num_text_entries += 1 - if num_text_entries > 0: - new_bytes_prob /= num_text_entries - - # print(ch["logs"].rstrip("\n"), flush=True) + progress.append(j) + # don't print warnings if log_level >= 0, since we're + # going to print them anyway below together with the + # rest of the logs + elif ln.startswith("Warning: ") and self.log_level < 2: + if self.log_level >= 1: + print(ln, flush=True) + progress = LLProgress.model_validate(progress) + + if self.log_level >= 2: + print(ch["logs"].rstrip("\n"), flush=True) err = ch.get("error", "") if err: raise RuntimeError(f"Error returned by grammar server {err}.") - is_generated = True # TODO: get this from the server - - response_data = EngineCallResponse( - new_bytes, - is_generated, - new_bytes_prob, - capture_groups, - capture_group_log_probs, - new_token_count, - ) - yield response_data + # TODO: these metrics may be a little off -- notice the `-1` (which is a hack for passing + # tests in tests/model_integration/library/test_gen.py for now, may have to do with BOS?) + usage = d["usage"] + self.metrics.engine_input_tokens = usage["ff_tokens"] + self.metrics.engine_output_tokens = usage["sampled_tokens"] - 1 + + yield progress.to_engine_call_response() + elif decoded_line == "data: [DONE]": pass else: @@ -122,46 +124,42 @@ def __call__(self, parser, grammar, ensure_bos_token=True): class AzureGuidance(Model): + def __init__( self, model=None, echo=True, max_streaming_tokens=1000, chat_template=None, + log_level: Optional[int] = None, ): """Build a new remote grammar processing Azure model object that represents a model in a given state.""" - - engine = AzureGuidanceEngine(model, max_streaming_tokens, chat_template) + if log_level is None: + log_level = int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")) + engine = AzureGuidanceEngine(model, max_streaming_tokens, chat_template, log_level) super().__init__(engine, echo=echo) -def _parse_base_url(base_url: str): - p = urllib.parse.urlparse(base_url) - key = "" +def _mk_url(path: str, conn_str: str): + p = urllib.parse.urlparse(conn_str) + headers = {} + info = "no auth header" if p.fragment: f = urllib.parse.parse_qs(p.fragment) - key = f.get("key", [""])[0] - r = urllib.parse.urlunparse(p._replace(fragment="", query="")) - if not r.endswith("/"): - r += "/" - return r, key - - -def _headers(arg_base_url: str) -> dict: - _, key = _parse_base_url(arg_base_url) - if key: - return {"api-key": key} + if key := f.get("key", [""])[0]: + headers = {"api-key": key} + info = f"api-key: {key[0:2]}...{key[-2:]}" + elif key := f.get("auth", [""])[0]: + headers = {"authorization": "Bearer " + key} + info = f"authorization: Bearer {key[0:2]}...{key[-2:]}" + url = urllib.parse.urlunparse(p._replace(fragment="", query="")) + if url.endswith("/"): + url = url[:-1] + if url.endswith("/run"): + url = url[:-4] + "/" + path + elif url.endswith("/guidance") and path == "run": + url = url else: - return {} - - -def _mk_url(path: str, arg_base_url: str) -> str: - pref, _ = _parse_base_url(arg_base_url) - return pref + path - - -def req(tp: str, path: str, base_url: str, **kwargs): - url = _mk_url(path, arg_base_url=base_url) - headers = _headers(arg_base_url=base_url) - resp = requests.request(tp, url, headers=headers, **kwargs) - return resp + url = url + "/" + path + info = f"{url} ({info})" + return url, headers, info diff --git a/guidance/models/_byte_tokenizer.py b/guidance/models/_byte_tokenizer.py index 9d0d3e6e2..a4a342599 100644 --- a/guidance/models/_byte_tokenizer.py +++ b/guidance/models/_byte_tokenizer.py @@ -1,15 +1,28 @@ import numpy as np from ._tokenizer import Tokenizer from ..chat import load_template_class -import typing +from typing import List class ByteTokenizer(Tokenizer): def __init__(self, chat_template=None): # directly map integer values to byte strings - tokens = np.array([bytes([i]) for i in range(256)], dtype="object") + all_bytes = [bytes([i]) for i in range(256)] + bos = b"" + tokens = np.array(all_bytes + [bos], dtype="object") chat_template = load_template_class(chat_template) - super().__init__(tokens, chat_template) + super().__init__(tokens, chat_template, bos_token_id=256) - def __call__(self, byte_string) -> typing.List[int]: + def encode(self, byte_string: bytes) -> List[int]: """Returns a list of tokens that represent the given byte string.""" - return list(byte_string) + if isinstance(byte_string, str): + byte_string = byte_string.encode("utf8") + i = 0 + result = [] + while i < len(byte_string): + if byte_string[i:i+3] == b'': + result.append(256) + i += 3 # Skip the next two characters as part of '' + else: + result.append(byte_string[i]) + i += 1 + return result diff --git a/guidance/models/_googleai.py b/guidance/models/_googleai.py index 8cf1c2a5a..4351e39b6 100644 --- a/guidance/models/_googleai.py +++ b/guidance/models/_googleai.py @@ -4,6 +4,14 @@ import tiktoken import os +try: + import google.generativeai as genai + + has_genai = True +except ImportError: + has_genai = False + + _image_token_pattern = re.compile(r"<\|_image:(.*)\|>") @@ -18,9 +26,7 @@ def __init__( compute_log_probs, **kwargs, ): - try: - import google.generativeai as genai - except ModuleNotFoundError: + if not has_genai: raise Exception( "Please install the Google AI Studio(makersuite.google.com) package using `pip install google-generativeai google-ai-generativelanguage` in order to use guidance.models.GoogleAI!" ) diff --git a/guidance/models/_grammarless.py b/guidance/models/_grammarless.py index fd0c03f05..9a622d426 100644 --- a/guidance/models/_grammarless.py +++ b/guidance/models/_grammarless.py @@ -1,3 +1,4 @@ +import os import logging import queue import threading @@ -14,6 +15,13 @@ logger = logging.getLogger(__name__) +try: + from .. import cpp # type: ignore[attr-defined] +except ImportError: + logger.warn( + "Failed to load guidance.cpp, falling back to Python mirror implementations..." + ) + from .. import _cpp as cpp class GrammarlessTokenizer(Tokenizer): def __init__(self, tokenizer): @@ -122,7 +130,7 @@ def __init__(self, tokenizer): super().__init__(byte_tokens, chat_template, bos_token_id, eos_token_id) - def encode(self, byte_string: bytes) -> Sequence[int]: + def encode(self, byte_string: bytes) -> list[int]: """Returns a list of tokens that represent the given byte string.""" assert isinstance(byte_string, bytes) return self._orig_tokenizer.encode(byte_string.decode()) @@ -165,6 +173,11 @@ def __init__( # build the Engine super().__init__(tokenizer=tokenizer, compute_log_probs=compute_log_probs) + # build a prefix tree of the tokens + self._token_trie = cpp.ByteTrie( + self.tokenizer.tokens, np.arange(len(self.tokenizer.tokens)) + ) + def _generator(self, prompt: bytes, temperature: float): raise NotImplementedError("Child classes must implement _generator()") @@ -199,12 +212,8 @@ def _start_generator_stream(self, generator): dqueue.get() dqueue.put(e) - if self._running_stream(): - dqueue.put(self.tokenizer.eos_token) self._not_running_stream.set() - dqueue.put( - b"" - ) # so we never get stuck waiting for a running stream to return something + dqueue.put(b"") # so we never get stuck waiting for a running stream to return something def _start_new_stream(self, prompt: bytes, temperature: float) -> None: assert isinstance(prompt, bytes) @@ -248,25 +257,25 @@ def _reset_shared_data(self, new_data: bytes, temperature: float): self._data = new_data self._last_stream_start = self._data - def get_logits( - self, token_ids: Sequence[int], forced_bytes: bytes, current_temp: float - ): - """Computes the logits for the given token state. - - This overrides a method from the Local class that is used to get - inference results from the model. - """ + def get_next_token( + self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: logger.debug( - f"Start Grammarless.get_logits({token_ids=}, {forced_bytes=}, {current_temp=})" + f"Start Grammarless.get_next_token({token_ids=}, {mask=}, {temperature=})" ) if len(token_ids) == 0: raise ValueError("token_ids must contain some tokens.") # compute the prompt bytes + # TODO: we need to get the forced bytes from the mask -- should streamline this? + if mask is not None: + forced_bytes = os.path.commonprefix([self.tokenizer.tokens[i] for i, b in enumerate(mask) if b != 0]) + else: + forced_bytes = b"" + whole_token_prompt = self.tokenizer.decode(token_ids) prompt = whole_token_prompt + forced_bytes - logger.debug(f"Grammarless.get_logits: {prompt=}") + logger.debug(f"Grammarless.get_next_token: {prompt=}") self._last_call = time.time() @@ -274,35 +283,35 @@ def get_logits( token_id = None restarted = False # track if we have restarted the data stream during this call while True: - logger.debug(f"Grammarless.get_logits: Starting main loop") + logger.debug(f"Grammarless.get_next_token: Starting main loop") # if the generation temperature changes we have to restart - if self._current_temp != current_temp: - logger.debug(f"Grammarless.get_logits: Starting new stream") - self._start_new_stream(prompt, current_temp) + if self._current_temp != temperature: + logger.debug(f"Grammarless.get_next_token: Starting new stream") + self._start_new_stream(prompt, temperature) continue # try and get the next token id elif self._data.startswith(prompt): - logger.debug(f"Grammarless.get_logits: Getting next token id") + logger.debug(f"Grammarless.get_next_token: Getting next token id") token_id = self._get_next_token(len(prompt) - len(forced_bytes)) - logger.debug(f"Grammarless.get_logits: {token_id=}") + logger.debug(f"Grammarless.get_next_token: {token_id=}") if token_id is not None: # if we have a non-zero sampling temperature we can't reuse bytes new_used_len = len(whole_token_prompt) + len( self.tokenizer.tokens[token_id] ) - logger.debug(f"Grammarless.get_logits: {new_used_len=}") - if current_temp > 0 and self._used_bytes_len >= new_used_len: - logger.debug(f"Grammarless.get_logits: Need to restart stream") + logger.debug(f"Grammarless.get_next_token: {new_used_len=}") + if temperature > 0 and self._used_bytes_len >= new_used_len: + logger.debug(f"Grammarless.get_next_token: Need to restart stream") token_id = None - self._start_new_stream(prompt, current_temp) + self._start_new_stream(prompt, temperature) continue # ...otherwise we have found the token id we want to emit else: - logger.debug(f"Grammarless.get_logits: Found token id") + logger.debug(f"Grammarless.get_next_token: Found token id") self._used_bytes_len = len(whole_token_prompt) + len( self.tokenizer.tokens[token_id] ) @@ -312,7 +321,7 @@ def get_logits( elif not self._data.startswith(prompt) and len(self._data) >= len( prompt ): # not prompt.startswith(self._data): # len(self._data) >= len(prompt) or - logger.debug(f"Grammarless.get_logits: Data will not match prompt") + logger.debug(f"Grammarless.get_next_token: Data will not match prompt") # check if we have already restarted once and so retrying by default is not likely to be helpful if restarted: raise self._report_failed_match(prompt) @@ -328,7 +337,7 @@ def get_logits( if not found_mismatch: match_len = len(prompt) leftover = prompt[match_len:] - logger.debug(f"Grammarless.get_logits: {leftover=}") + logger.debug(f"Grammarless.get_next_token: {leftover=}") # record any active non-empty role ends. Ignore role ends that are spaces parts: Sequence[Optional[bytes]] = [ @@ -347,7 +356,7 @@ def get_logits( # see if adding an end token would work here (if so we avoid recalling the server and just produce an end token) found_match = False for p in parts: - logger.debug(f"Grammarless.get_logits: Considering part {str(p)}") + logger.debug(f"Grammarless.get_next_token: Considering part {str(p)}") if p is not None: if p.startswith(leftover): self._data = self._data[:match_len] + p @@ -363,7 +372,7 @@ def get_logits( f"restarting a stream because the data we have does not match the ids. We have {str(self._data)} but the prompt is {str(prompt)}" ) restarted = True - self._start_new_stream(prompt, current_temp) + self._start_new_stream(prompt, temperature) # extend our data with a chunk from the model stream if not self._data_queue.empty(): @@ -381,11 +390,15 @@ def get_logits( # but if there is nothing and we are not running then we start a stream elif self._not_running_stream.is_set(): + if (self.tokenizer.eos_token_id is not None) and ( + mask is None or mask[self.tokenizer.eos_token_id] != 0 + ): + return self.tokenizer.eos_token_id logger.debug( "starting a new stream because there is no data to read and no stream running..." ) restarted = True - self._start_new_stream(prompt, current_temp) + self._start_new_stream(prompt, temperature) # we wait for the running stream to put something in the queue else: @@ -400,20 +413,7 @@ def get_logits( # reset out call time to allow the data stream to time out if we happen to be done with it self._last_call = time.time() - # # if we don't have the next byte of data yet then we wait for it (from the streaming thread) - # if len(self._data) == len(prompt): - # self._data += self._data_queue.get() - - # token_id = self._get_next_token(len(prompt)) - - # set the logits to the next byte the model picked - logger.debug(f"Grammarless.get_logits: Creating logits for {token_id=}") - logits = np.ones(len(self.tokenizer.tokens)) * -np.inf - logits[token_id] = 100 - if token_id != self.tokenizer.eos_token: - # we always allow the model to use EOS if that is the only way forward - logits[self.tokenizer.eos_token_id] = 0 - return logits + return token_id def _report_failed_match(self, prompt: bytes): logger.debug(f"_report_failed_match: {prompt=}") diff --git a/guidance/models/_guidance_engine_metrics.py b/guidance/models/_guidance_engine_metrics.py deleted file mode 100644 index cc2c36cfb..000000000 --- a/guidance/models/_guidance_engine_metrics.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel, NonNegativeInt - - -class GuidanceEngineMetrics(BaseModel): - engine_input_tokens: NonNegativeInt = 0 - engine_output_tokens: NonNegativeInt = 0 diff --git a/guidance/models/_mock.py b/guidance/models/_mock.py index a4fa1e8e5..0f1e48b41 100644 --- a/guidance/models/_mock.py +++ b/guidance/models/_mock.py @@ -1,21 +1,56 @@ -from typing import Sequence - +from typing import Sequence, Optional import numpy as np +import logging from ._model import Engine, Model, Chat from ._remote import RemoteEngine from ._tokenizer import Tokenizer +logger = logging.getLogger(__name__) + +# TODO: this import pattern happens in a few places, should be cleaned up +try: + from .. import cpp # type: ignore[attr-defined] +except ImportError: + logger.warn( + "Failed to load guidance.cpp, falling back to Python mirror implementations..." + ) + from .. import _cpp as cpp class MockTokenizer(Tokenizer): def __init__(self, tokens: Sequence[bytes]): - super().__init__(tokens, chat_template=None, bos_token_id=0, eos_token_id=0) + self.byte_trie = cpp.ByteTrie(self.tokens, np.arange(len(self.tokens))) + + def encode(self, byte_string: bytes) -> list[int]: + """Simple greedy tokenizer + TODO: could be a method on ByteTrie if we want to reuse it + """ + pos = 0 + tokens = [] + while pos < len(byte_string): + current_node = self.byte_trie + last_match = None + match_pos = pos + + while match_pos < len(byte_string) and current_node.has_child(byte_string[match_pos : match_pos + 1]): + current_node = current_node.child(byte_string[match_pos : match_pos + 1]) + if current_node.value >= 0: + last_match = (current_node.value, match_pos + 1) + match_pos += 1 + + if last_match is not None: + tokens.append(last_match[0]) + pos = last_match[1] + else: + raise ValueError(f"Could not find a match for byte {byte_string[pos]} at position {pos}") - def recode(self, tokens: Sequence[int]) -> Sequence[int]: - # Make a no-op for now return tokens + def recode(self, tokens: Sequence[int]) -> list[int]: + # Make a no-op for now + return list(tokens) + class MockEngine(Engine): def __init__(self, tokenizer, byte_patterns, compute_log_probs, force): @@ -45,10 +80,12 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs, force): # seed the random number generator self._rand_generator = np.random.default_rng(seed=42) - def get_logits(self, token_ids, forced_bytes, current_temp): - """Pretends to compute the logits for the given token state.""" - self.called_temperatures.append(current_temp) + def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: + self.called_temperatures.append(temperature) + return super().get_next_token(token_ids, mask, temperature) + def get_logits(self, token_ids: list[int]) -> np.ndarray: + """Pretends to compute the logits for the given token state.""" # build the byte strings byte_string = b"".join(self.tokenizer.tokens[i] for i in token_ids) @@ -76,6 +113,15 @@ def get_logits(self, token_ids, forced_bytes, current_temp): return logits def _get_next_tokens(self, byte_string): + special_tokens = [ + (self.tokenizer.bos_token_id, self.tokenizer.bos_token), + (self.tokenizer.eos_token_id, self.tokenizer.eos_token) + ] + for i, t in special_tokens: + # if the byte string starts with a special token then make sure we don't yield any other tokens + if byte_string.startswith(t): + yield i + return for i, t in enumerate(self.tokenizer.tokens): if byte_string.startswith(t): yield i diff --git a/guidance/models/_model.py b/guidance/models/_model.py index b7f9eb4cb..74f3c8af0 100644 --- a/guidance/models/_model.py +++ b/guidance/models/_model.py @@ -8,11 +8,9 @@ import threading import time import warnings -from typing import Union - from pprint import pprint -from typing import Dict, TYPE_CHECKING +from typing import Any, Dict, Iterator, List, Optional, Union, TYPE_CHECKING import numpy as np @@ -23,25 +21,12 @@ ipython_is_imported = True except ImportError: ipython_is_imported = False -try: - import torch - - torch_is_imported = True -except ImportError: - torch_is_imported = False - logger = logging.getLogger(__name__) -try: - from .. import cpp # type: ignore[attr-defined] -except ImportError: - logger.warn( - "Failed to load guidance.cpp, falling back to Python mirror implementations..." - ) - from .. import _cpp as cpp -from ._guidance_engine_metrics import GuidanceEngineMetrics + +from .._schema import EngineCallResponse, GuidanceEngineMetrics from .._utils import softmax, CaptureEvents -from .._parser import EarleyCommitParser, Parser +from .._parser import TokenParser from .._grammar import ( GrammarFunction, string, @@ -52,10 +37,6 @@ unreplace_model_variables, select, ) - -from .. import _serialization_pb2 -from ..chat import load_template_class - from ._tokenizer import Tokenizer if TYPE_CHECKING: @@ -71,116 +52,6 @@ image_pattern = re.compile(r"<\|_image:(.*?)\|>") - - -class EngineCallResponse: - new_bytes: bytes - is_generated: bool - new_bytes_prob: float - capture_groups: dict - capture_group_log_probs: dict - new_token_count: int - - def __init__( - self, - new_bytes, - is_generated, - new_bytes_prob, - capture_groups, - capture_group_log_probs, - new_token_count, - ): - self.new_bytes = new_bytes - self.is_generated = is_generated - self.new_bytes_prob = new_bytes_prob - self.capture_groups = capture_groups - self.capture_group_log_probs = capture_group_log_probs - self.new_token_count = new_token_count - - def _to_proto(self): - """Converts an EngineCallResponse object to its Protobuf representation. - - Returns: - engine_response_pb2.EngineCallResponse: The Protobuf equivalent of this object. - """ - groups = {} - group_log_probs = {} - - def to_protobuf_value(v: Union[str, bytes, float, list]) -> _serialization_pb2.Value: - """Convert Python values to Protobuf Value messages.""" - value = _serialization_pb2.Value() - if isinstance(v, str): - value.string_value = v - elif isinstance(v, bytes): - value.bytes_value = v - elif isinstance(v, float): - value.float_value = v - elif isinstance(v, list): - for item in v: - value.list_value.values.append(to_protobuf_value(item)) - else: - raise TypeError(f"Unsupported type: {type(v)}") - return value - - for k, v in self.capture_groups.items(): - groups[k] = to_protobuf_value(v) - - for k, v in self.capture_group_log_probs.items(): - group_log_probs[k] = to_protobuf_value(v) - - return _serialization_pb2.EngineCallResponse( - new_bytes=self.new_bytes, - is_generated=self.is_generated, - new_bytes_prob=self.new_bytes_prob, - capture_groups=groups, - capture_group_log_probs=group_log_probs, - new_token_count=self.new_token_count, - ) - - def encode(self, charset): - """Used to support FastAPI encoding of EngineCallResponse objects.""" - return self.serialize() - - def serialize(self): - proto = self._to_proto() - return proto.SerializeToString() - - @staticmethod - def deserialize(byte_data): - proto = _serialization_pb2.EngineCallResponse() - proto.ParseFromString(byte_data) - - def from_protobuf_value(value: _serialization_pb2.Value) -> Union[str, bytes, float, list]: - """Convert Protobuf Value message to Python values""" - if value.HasField("string_value"): - return value.string_value - elif value.HasField("bytes_value"): - return value.bytes_value - elif value.HasField("float_value"): - return value.float_value - elif value.HasField("list_value"): - return [from_protobuf_value(item) for item in value.list_value.values] - else: - raise ValueError("Protobuf Value message has no recognized field set") - - groups = {} - for k, v in proto.capture_groups.items(): - groups[k] = from_protobuf_value(v) - - group_log_probs = {} - for k, v in proto.capture_group_log_probs.items(): - group_log_probs[k] = from_protobuf_value(v) - - return EngineCallResponse( - new_bytes=proto.new_bytes, - is_generated=proto.is_generated, - new_bytes_prob=proto.new_bytes_prob, - capture_groups=groups, - capture_group_log_probs=group_log_probs, - new_token_count=proto.new_token_count, - ) - - class Engine: """The engine owns the inference computation and is used/created by the Model class. @@ -193,14 +64,6 @@ class Engine: def __init__(self, tokenizer: Tokenizer, compute_log_probs=False): self.tokenizer = tokenizer self.compute_log_probs = compute_log_probs - - # build a prefix tree of the tokens - self._token_trie = cpp.ByteTrie( - self.tokenizer.tokens, np.arange(len(self.tokenizer.tokens)) - ) - self._token_trie.match = True - self._token_trie.match_version = 0 - self.metrics = GuidanceEngineMetrics() def get_chat_template(self): # TODO [HN]: Add more logic here...should we instantiate class here? do we even need to? @@ -209,645 +72,113 @@ def get_chat_template(self): # TODO [HN]: Add more logic here...should we instan def reset_metrics(self): self.metrics = GuidanceEngineMetrics() - def start(self, parser, grammar, ensure_bos_token=True): + def start(self, prompt, grammar, ensure_bos_token=True) -> TokenParser: """Start processing parser state executed through the grammar. Parameters ---------- - parser : str or Parser + prompt : str or Parser This is represents the current state of a guidance parser that will be extended using the passed grammar. If a string is given then we assume the previous parser state is just a fixed string prompt, if a full Parser is given then we extend that parser by appending the new grammar to the parser's current grammar and then inferencing the model. (TODO: implement full parser extension support) grammar: Grammar - This is the grammar we are extending the parser with. + This is the grammar we are extending the prompt with. """ # def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, ensure_bos_token=True): # assert n == 1, "Still need to add support for n > 1!" + # TODO: re-enable this? llguidance currently doesn't support model variables # note we only support a fixed set of engine variables for the sake of security - self._replacements = replace_model_variables( - grammar, self, allowed_vars=["eos_token", "bos_token"] - ) + # self._replacements = replace_model_variables( + # grammar, self, allowed_vars=["eos_token", "bos_token"] + # ) # right now we only support a text/bytes prompt parser state, so we extract that - if isinstance(parser, bytes): - prompt = parser - elif isinstance(parser, str): - prompt = bytes(parser, encoding="utf8") - elif isinstance(parser, Parser): + if isinstance(prompt, bytes): + prompt = prompt + elif isinstance(prompt, str): + prompt = bytes(prompt, encoding="utf8") + elif isinstance(prompt, TokenParser): raise NotImplementedError( "Still need to implement support for extending a full Parser state." ) else: - raise Exception("The passed parser is of an unknown type!") - - # add the beginning of sequence token if needed - if ( - ensure_bos_token - and self.tokenizer.bos_token is not None - and not prompt.startswith(self.tokenizer.bos_token) - ): - prompt = self.tokenizer.bos_token + prompt - - # run a simple tokenizer (that does not use a grammar) on the prefix for better performance - self._token_ids, self._token_byte_positions = self._tokenize_prefix(prompt) - self._token_ids, self._token_byte_positions = self._cleanup_tokens( - self._token_ids, self._token_byte_positions - ) - if len(self._token_byte_positions) > 0: - self._pre_parser_bytes = self._token_byte_positions[-1] - self._trimmed_prompt_prefix = prompt[: self._token_byte_positions[-1]] - prompt = prompt[self._token_byte_positions[-1] :] - else: - self._trimmed_prompt_prefix = b"" - self._pre_parser_bytes = 0 - - # create a parser with a grammar that includes both our context and the passed grammar - self._parser = EarleyCommitParser(prompt + grammar) - - # loop until we have generated a complete pattern - self._hidden_count = len(prompt) # we don't emit the prompt - self._generated_pos = 0 - self._sampled_token_ind = None - self._token_count = 0 - self._last_token_count = 0 - self._was_forced = False - self._captured_data = {} - self._captured_log_prob_data = {} - - def next(self, logits): - """Move the grammar state machine processing forward to the next point where - either get_logits is required to be called or we have a partial response - to stream back. - - Parameters - ---------- - logits : the logits obtained from the LLM after the last return from next(...) - """ - - logits_state = None - response_state = None - - token_pos = 0 - is_generated = True - - is_new_token = False - if logits is not None: - is_new_token = True - - # if requested we compute the log probabilities so we can track the probabilities of each node - if self.compute_log_probs: - if torch_is_imported: - # note we don't adjust for temp since we consider that a sampling step, not part of the probs - probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy() - else: - # this numpy code is slower, so we don't use it if we have torch... - probs = softmax(logits, axis=-1) - self.tokenizer.clean_duplicate_tokens(probs) - self._trie.compute_probs(probs) # C++ impl - else: - probs = None - - grammar_temp = self._parser.next_byte_temperature() - current_temp = grammar_temp if grammar_temp >= 0 else 0 - - # get the sampling order - if current_temp == 0: - # we need numpy so the enumerate below does not get really slow... - sampling_order = np.argsort(-logits) - else: - # assert top_p == 1, "Still need to add support for top_p!" - if torch_is_imported: - logits = torch.tensor(logits) - torch.div(logits, current_temp, out=logits) - probs_torch = torch.nn.functional.softmax(logits, dim=-1) - sampling_order = torch.multinomial(probs_torch, len(probs_torch)).cpu().numpy() - else: - # this numpy version allows us to drop our dependence on pytorch...but it is way slower - if probs is None: - probs = softmax(logits / current_temp, axis=-1) - probs += 1e-10 # ensure we have no zero probs that mess up numpy - probs /= np.sum(probs) - sampling_order = np.random.choice( - len(probs), size=len(probs), p=probs, replace=False - ) # the 1e-10 is ensure we have no zero probs, which numpy does not like - - # loop over the tokens looking for a valid one - for i, self._sampled_token_ind in enumerate(sampling_order): - self._sampled_token = self.tokenizer.tokens[self._sampled_token_ind] - - # break out if we have reach impossible tokens - if logits[self._sampled_token_ind] <= -np.inf: - break - - # make sure it matches any forced prefix - used_forced_pos = min(self._forced_pos, self._start_pos + len(self._sampled_token)) - if ( - self._start_pos < self._forced_pos - and not self._sampled_token.startswith( - self._parser.bytes[self._start_pos : used_forced_pos] - ) - ): - continue - offset = used_forced_pos - self._start_pos - - # make sure the parse is backed up to the position we want to start checking from TODO: make this account for shared prefixes with the last token - self._parser.pos = used_forced_pos - self._new_bytes_prob = 1.0 - - # if we have gotten to the end of the valid tokens then we stop - # if logits[self._sampled_token_ind] == -np.inf: - # raise self._report_failed_match(self._trimmed_prompt_prefix + self._parser.bytes) - - # check to see if the sampled token is allowed - token_pos = offset - - # this is the Trie node we were left at when we could force the next byte above - node = self._trie - - while token_pos < len(self._sampled_token): - next_byte = self._sampled_token[token_pos : token_pos + 1] - next_node = node.child(next_byte) - - # if we don't have a cached match flag compute it using the grammar - if next_node.match_version < self._token_trie.match_version: - next_byte_mask = self._parser.next_byte_mask() - - # we update all the children since the parser knows the full mask - for byte in node.keys(): - child = node.child(byte) - child.match_version = self._token_trie.match_version - child.match = next_byte_mask[byte[0]] - - # advance or fail according to the (now up-to-date) match cache - if next_node.match: + raise Exception("The passed prompt is of an unknown type!") - # get the parser to consume the next byte - if next_node.prob < 1e-8: - if node.prob < 1e-8: - log_prob_delta = 0 - else: - log_prob_delta = -20 - else: - log_prob_delta = np.log(next_node.prob) - np.log(node.prob) - # log_prob_delta = np.log(next_node.prob) - np.log(node.prob) - self._new_bytes_prob = next_node.prob - commit_point = self._parser.consume_byte( - next_byte, log_prob=log_prob_delta - ) - - # mark that we accepted this byte - node = next_node - token_pos += 1 - - # if we are at a hidden commit point then we need to hide the bytes that match that node - if commit_point is not None and commit_point.node.hidden: - - # if we are capturing the data from this node we need to do that now since we are about to remove it - # TODO: build a whole parse tree under this commit_point node so we can record child node captures - if commit_point.node.capture_name: - self._captured_data[commit_point.node.capture_name] = ( - self._parser.bytes[commit_point.start :] - ) - self._captured_log_prob_data[ - commit_point.node.capture_name - ] = commit_point.log_prob - - # This takes the item and commits to it as part of the parse and then shrinks it to zero width - # in other words this hides the item - self._parser.commit_and_collapse_item(commit_point) - - # keep the bytes we still need to emit - if self._forced_pos < commit_point.start: - self._parser.shadow_rewind(self._forced_pos) - - else: - # pop off any tokens that overlap the hidden bytes - i = len(self._token_byte_positions) - 1 - while ( - i >= 0 - and self._token_byte_positions[i] - - self._pre_parser_bytes - > commit_point.start - ): - self._token_ids.pop() - self._token_byte_positions.pop() - self._token_count -= 1 - i -= 1 - # re-add any bytes we cut too far on - self._parser.shadow_rewind( - self._token_byte_positions[-1] - - self._pre_parser_bytes - ) - is_new_token = False - break - - elif token_pos == len(self._sampled_token): - break # this token is valid - else: - # partially valid tokens are okay if we are running off the end of a grammar, but not otherwise - if not self._parser.matched(): - token_pos = -1 - - break # this token is no longer valid - - # see if we are breaking out of the whole loop - if not is_new_token: - break - - # check if this token is dominated by other longer valid tokens (and hence would never be consistent with greedy tokenization) - # TODO: disabled for now because of sentencepeice non-local issues - # if token_pos == len(self._sampled_token) and not self._parser.matched(): # not we don't check if we have matched, because then we can generate anything afterwards - # if _check_dominated(node, self._parser, self._token_trie.match_version, self._parser.next_byte_mask()): - # token_pos = -1 - - if token_pos > 0: - break # we found a valid token - - if self._parser.matched(): - break # if we already have a full match we don't try more tokens we just give up as soon as the model deviates from the grammar - - is_done = False - while True: # each iteration generates one more token (and some of the associated bytes) - if is_new_token: - # emit whatever we know will not be hidden - new_bytes = self._parser.bytes[self._generated_pos : self._parser.earliest_hidden_start()] - - # if we cannot consume any more tokens then we are done - if ( - not self._is_forced - and token_pos < len(self._sampled_token) - and self._trie == self._token_trie - ): - - # which if can't consume any more tokens, but we are not yet done - if not self._parser.matched(): - self._parser.matched() - raise self._report_failed_match( - self._trimmed_prompt_prefix + self._parser.bytes - ) - - # TODO: if we exactly match the end of the pattern then we can commit to this last token - # if m.span()[1] == len(generated_text): - # self._cache_state["new_token_ids"].append(self._sampled_token_ind) - - # capture the named groups from the parse tree - self._parser.get_captures(self._captured_data, self._captured_log_prob_data) - - # we have no valid log prob data if we didn't compute it - # yield new_bytes[self._hidden_count:], self._is_generated, self._new_bytes_prob, self._captured_data, self._captured_log_prob_data, token_count - last_token_count - - response_state = ( - new_bytes[self._hidden_count :], - is_generated, - self._new_bytes_prob if self.compute_log_probs else 1.0, - self._captured_data, - self._captured_log_prob_data, - self._token_count - self._last_token_count, - ) - - self._last_token_count = self._token_count - - # TODO: we only need to do this when we might re-use the grammar object...we might want to account for that - unreplace_model_variables(self._replacements) - - is_done = True - else: - self._generated_pos += len(new_bytes) - - # yeild the snippet of text created by the next token - out = new_bytes[self._hidden_count :] - if len(out) > 0: - # capture the named groups from the (partial) parse tree, # TODO: disabled for now until we handle list_append correctly - # new_captured_data, new_captured_log_prob_data = self._parser.get_captures() - # self._captured_data.update(new_captured_data) - # self._captured_log_prob_data.update(new_captured_log_prob_data) - # yield out, self._is_generated, self._new_bytes_prob, self._captured_data, self._captured_log_prob_data, self._token_count - self._last_token_count # note that we don't capture groups until a complete parse right now... - - response_state = ( - out, - is_generated, - self._new_bytes_prob if self.compute_log_probs else 1.0, - self._captured_data, - self._captured_log_prob_data, - self._token_count - self._last_token_count, - ) - - self._last_token_count = self._token_count - self._hidden_count = 0 - self._token_count += 1 # note we only update this for tokens that emit non-hidden content - else: - self._hidden_count -= len(new_bytes) - - self._token_ids.append(self._sampled_token_ind) - - # track the byte position of each token - if len(self._token_byte_positions) == 0: - self._token_byte_positions.append(len(self._sampled_token)) - else: - self._token_byte_positions.append( - self._token_byte_positions[-1] + len(self._sampled_token) - ) - - if response_state is not None: - break - - token_pos = 0 - is_generated = False - - is_new_token = True - - # note where we are starting for this token - self._start_pos = self._parser.pos - - # let the parser know that we have advanced another token (used ofr tracking max token limits) - self._parser.mark_new_token() - - # walk down the trie as far as possible before computing the logits - self._trie = self._token_trie - - # this invalidates all the match caches from the previous token - self._trie.match_version += 1 - # self._trie.prob = 0.0 # need to reset when we reset the match_version - while True: - next_byte_mask = self._parser.next_byte_mask() - next_byte_mask_sum = next_byte_mask.sum() - - # see if we reached a dead end of the grammar - if next_byte_mask_sum == 0: - break - - # if there is more than one option we cannot advance without computing the logits - elif next_byte_mask_sum != 1: - break - - # we are not forced if we are at the end of the grammar - elif self._parser.matched(): - break - - # if there is only one possible next byte we can keep forcing - elif next_byte_mask_sum == 1: - - # look for valid children - next_byte = None - for byte in self._trie.keys(): - - # mark this self._trie node with an up-to-date match flag (may save work later) - node = self._trie.child(byte) - node.match_version = self._token_trie.match_version - # node.prob = 0.0 # reset when we reset the match_version - node.match = next_byte_mask[byte[0]] - - # see if we found a match - if node.match: - next_byte = byte - break - - # if we can't extend then this token is forced - if next_byte is None: - break - - # otherwise since there is only one possible next byte we keep going - else: - commit_point = self._parser.consume_byte( - next_byte, log_prob=0.0 - ) - - # if we are at a hidden commit point then we need to hide the bytes that match that node - if commit_point is not None and commit_point.node.hidden: - - # This takes the item and commits to it as part of the parse and then shrinks it to zero width - # in other words this hides the item - self._parser.commit_and_collapse_item(commit_point) - - # keep the bytes we still need to emit - if self._start_pos < commit_point.start: - self._parser.shadow_rewind(self._start_pos) - - else: - # pop off any tokens that overlap the hidden bytes - i = len(self._token_byte_positions) - 1 - while ( - i >= 0 - and self._token_byte_positions[i] - - self._pre_parser_bytes - > commit_point.start - ): - self._token_ids.pop() - self._token_byte_positions.pop() - self._token_count -= 1 - i -= 1 - # re-add any bytes we cut too far on - self._parser.shadow_rewind( - self._token_byte_positions[-1] - - self._pre_parser_bytes - ) - is_new_token = False # this restarts us at the top of the outer token gen loop - break - - self._trie = self._trie.child(next_byte) - - self._forced_pos = self._parser.pos # record how far the bytes are forced - - if is_new_token: - # back up if we got forced up to a point that is not a valid token - if next_byte_mask_sum <= 1: - while self._trie.value < 0 and self._trie.parent() is not None: - self._trie = self._trie.parent() - self._forced_pos -= 1 - self._parser.pos = self._forced_pos - - # if we walked all the way to a forced token then we advance without computing the logits - # we are forced if there are no more options and we are either in the middle of the grammar or at a trie leaf - self._is_forced = next_byte_mask_sum <= 1 and ( - len(self._trie) == 0 - if self._parser.matched() - else self._trie != self._token_trie - ) - if self._is_forced: - self._sampled_token_ind = self._trie.value - self._sampled_token = self.tokenizer.tokens[self._sampled_token_ind] - self._new_bytes_prob = 1.0 - self._was_forced = True - - # we are at the end of the grammar - elif next_byte_mask_sum == 0: - - # mark the token we "sampled" if we have comsumed some bytes - if self._trie != self._token_trie: - self._sampled_token_ind = self._trie.value - self._sampled_token = self.tokenizer.tokens[ - self._sampled_token_ind - ] - self._new_bytes_prob = 1.0 - - # otherwise we need to compute the logits and sample a valid token - else: - - # if we were forced we might need to clean up the greedy tokenization to match the global tokenization behavior as seen in training - if self._was_forced: - self._token_ids, self._token_byte_positions = ( - self._cleanup_tokens( - self._token_ids, self._token_byte_positions - ) - ) - self._was_forced = False - - grammar_temp = self._parser.next_byte_temperature() - current_temp = grammar_temp if grammar_temp >= 0 else 0 - logits_state = ( - self._token_ids, - self._parser.bytes[self._start_pos : self._forced_pos], - current_temp, - ) - break - - return is_done, logits_state, response_state + return TokenParser( + grammar=grammar, + tokenizer=self.tokenizer, + prompt=prompt, + ensure_bos_token=ensure_bos_token + ) - def __call__(self, parser, grammar, ensure_bos_token=True): - """Returns a new updated parser state executed through the grammar. + def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCallResponse]: + """Main entry point for the inference-parser loop. Yields EngineCallResponse objects as + the parser advances through the grammar. Parameters ---------- - parser : str or Parser + prompt : str or Parser This is represents the current state of a guidance parser that will be extended using the passed grammar. If a string is given then we assume the previous parser state is just a fixed string prompt, if a full Parser is given then we extend that parser by appending the new grammar to the parser's current grammar and then inferencing the model. (TODO: implement full parser extension support) grammar: Grammar - This is the grammar we are extending the parser with. + This is the grammar we are extending the prompt with. """ - - self.start(parser, grammar, ensure_bos_token) - - logits = None - while True: - is_done, logits_state, response_state = self.next(logits) - logits = None - - if response_state is not None: - ( - response_new_bytes, - response_is_generated, - response_new_bytes_prob, - response_capture_groups, - response_capture_group_log_probs, - response_new_token_count, - ) = response_state - - yield EngineCallResponse( - new_bytes=response_new_bytes, - is_generated=response_is_generated, - new_bytes_prob=response_new_bytes_prob, - capture_groups=response_capture_groups, - capture_group_log_probs=response_capture_group_log_probs, - new_token_count=response_new_token_count, - ) - - if logits_state is not None: - token_ids, forced_bytes, current_temp = logits_state - logits = self.get_logits(token_ids, forced_bytes, current_temp) - - if is_done: - break - - def _tokenize_prefix(self, byte_string): - """This is used to speed up the tokenization of long prompts without using the parser.""" - token_ids = [] - token_byte_positions = [] - - # loop trying to decode a new token at each iteration - pos = 0 - while True: - - # walk down the token trie looking for a unique token match - trie = self._token_trie - valid_pos = -1 - valid_value = -1 - while True: - if pos >= len(byte_string): - if len(trie) > 0: - valid_pos = -1 - break - - # check if we can keep going or are at a dead end - if trie.has_child(byte_string[pos : pos + 1]): - trie = trie.child(byte_string[pos : pos + 1]) - pos += 1 - - # record the last valid token down this path as we go - if trie.value >= 0: - valid_pos = pos - valid_value = trie.value + parser = self.start(prompt, grammar, ensure_bos_token) + + token = None + while not parser.done(): + gen_data, response = parser.advance(token) + + if gen_data is not None: + if parser.is_accepting() and self.tokenizer.eos_token_id is not None: + # Whenever we are in an accepting state, we will allow the model to generate whatever it wants + # but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully. + assert gen_data.mask[self.tokenizer.eos_token_id] + token = self.get_next_token( + token_ids=gen_data.tokens, + mask=None, + temperature=gen_data.temperature + ) + if not gen_data.mask[token]: + token = self.tokenizer.eos_token_id else: - break # we can't go any farther - - if valid_pos == -1: - break - else: - token_ids.append(valid_value) - token_byte_positions.append(valid_pos) - pos = valid_pos - - return token_ids, token_byte_positions - - def _cleanup_tokens(self, token_ids, token_byte_positions): - - # compute a joint tokenization - joint_token_ids = self.tokenizer.recode(token_ids) - - # see if we need to redo the tokenization - redo = False - if len(joint_token_ids) != len(token_ids): - redo = True - else: - for i, id in enumerate(joint_token_ids): - if token_ids[i] != id: - redo = True - break - - if redo: - token_ids = joint_token_ids - last_pos = token_byte_positions[-1] - token_byte_positions = [] - pos = 0 - for i, id in enumerate(joint_token_ids): - pos += len(self.tokenizer.tokens[id]) - token_byte_positions.append(pos) - - # ugly hack to deal with sentence piece craziness of space hiding after special tokens - # TODO: figure out how to make this more robust - if ( - token_byte_positions[-1] == last_pos + 1 - and self.tokenizer.tokens[token_ids[0]] == b"" - and self.tokenizer.tokens[token_ids[1]][0:1] == b" " - ): - for i in range(1, len(token_byte_positions)): - token_byte_positions[i] -= 1 - - # another ugly hack for tokenizers that are not stable on encode/decode cycles - # currently only Phi-3, should generalize this method if we see more of these - if token_byte_positions[-1] != last_pos: - if not hasattr(self, "_disable_retokenize_check"): - msg = textwrap.dedent( - """Self-consistency check in _cleanup_tokens() failed. - - This is not a fatal issue, but if there are subsequent - generation problems, please include this warning in - your bug report.""" + token = self.get_next_token( + token_ids=gen_data.tokens, + mask=gen_data.mask, + temperature=gen_data.temperature ) - warnings.warn(msg) - - return token_ids, token_byte_positions + else: + token = None - def get_logits(self, token_ids, forced_bytes, current_temp): - """A fake method designed to be overriden by subclasses.""" + yield response - # pretend to extend the KV cache and update the log probs - return np.randn(len(self.tokenizer.tokens)) + def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: + """Base implementation for getting the next token from the model which calls get_logits and sample_with_temperature. + Subclasses may override this method, e.g. if they use external APIs that do not support getting logits directly. + """ + logits = self.get_logits(token_ids) + token = self.sample_with_temperature(logits, mask, temperature) + return token + + def get_logits(self, token_ids: list[int]) -> np.ndarray: + raise NotImplementedError + + def sample_with_temperature(self, logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int: + if mask is not None: + logits += np.frombuffer(mask, dtype=np.uint8) + if temperature < 0.0001: + return int(np.argmax(logits)) + # Get probabilities from softmax + probabilities = softmax(logits/temperature) + # Sample an index based on the probabilities + sampled_index = np.random.choice(len(logits), p=probabilities) + return sampled_index def _report_failed_match(self, prompt): """Note that this can be overridden by subclasses that have more likely reasons than a bug in the token set (like remote models).""" @@ -1576,40 +907,3 @@ def __init__(self, *args, **kwargs): self.prompt = kwargs.pop("prompt", None) self.data = kwargs.pop("data", None) super().__init__(*args, **kwargs) - - -# def _compute_probs(trie, probs, found): -# '''Computes the log probabilities for each internal trie node.''' -# if trie.value is not None: -# found[trie.value] = 1 -# trie.prob += probs[trie.value] - -# if len(trie) > 0: -# # child_probs = [] -# for b in trie.keys(): -# child = trie.child(b) -# _compute_probs(child, probs, found) -# trie.prob += child.prob -# # trie.log_prob = np.logaddexp.reduce(child_log_probs) - - -def _check_dominated(node, parser, match_version, next_byte_mask): - curr_pos = parser.pos - for byte_num in next_byte_mask.nonzero()[0]: - next_byte = bytes((byte_num,)) - if not node.has_child(next_byte): - return False # no possible exension this direction, so we are not dominated - child = node.child(next_byte) - if child.match_version < match_version: - child.match_version = match_version - child.match = next_byte_mask[next_byte[0]] - - if not child.match: - return False # this child does not dominate the node, so the node is not dominated - elif child.value is None: # this child might not dominate the node - parser.consume_byte(next_byte, log_prob=0.0) - child_dominate = _check_dominated(child, parser, match_version, parser.next_byte_mask()) - parser.pos = curr_pos - if not child_dominate: - return False - return True \ No newline at end of file diff --git a/guidance/models/_remote.py b/guidance/models/_remote.py index 8ed3f982a..86db8fcc5 100644 --- a/guidance/models/_remote.py +++ b/guidance/models/_remote.py @@ -1,6 +1,6 @@ import requests import os -import base64 +import json from ._model import Engine, EngineCallResponse from ..chat import ChatMLTemplate @@ -33,7 +33,7 @@ def __call__(self, parser, grammar, ensure_bos_token=True): # Prepare the request data data = { "parser": parser, - "grammar": base64.b64encode(grammar.serialize()).decode("utf-8"), + "grammar": json.dumps(grammar.ll_serialize()), } headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} @@ -54,5 +54,5 @@ def __call__(self, parser, grammar, ensure_bos_token=True): # Process and yield the response data # chunk_size=None means it'll stream the content for chunk in response.iter_content(chunk_size=None): - response_data = EngineCallResponse.deserialize(chunk) + response_data = EngineCallResponse.model_validate_json(chunk) yield response_data diff --git a/guidance/models/_tokenizer.py b/guidance/models/_tokenizer.py index ede843ddb..a0f1a5e15 100644 --- a/guidance/models/_tokenizer.py +++ b/guidance/models/_tokenizer.py @@ -85,7 +85,7 @@ def chat_template(self) -> Union[Any, None]: def __call__(self, byte_string: bytes): return self.encode(byte_string) - def encode(self, byte_string: bytes) -> Sequence[int]: + def encode(self, byte_string: bytes) -> list[int]: """Returns a list of tokens that represent the given byte string.""" raise NotImplementedError( "You need to use a Tokenize subclass that overrides the encode method" @@ -95,7 +95,7 @@ def decode(self, tokens: Sequence[int]) -> bytes: """Returns the bytes represented by the given list of tokens.""" return b"".join([self.tokens[t] for t in tokens]) - def recode(self, tokens: Sequence[int]) -> Sequence[int]: + def recode(self, tokens: Sequence[int]) -> list[int]: """Redoes a tokenisation. Encoding a string into tokens does not distribute over concatenation. diff --git a/guidance/models/llama_cpp/_llama_cpp.py b/guidance/models/llama_cpp/_llama_cpp.py index 93da28f15..ee000d14a 100644 --- a/guidance/models/llama_cpp/_llama_cpp.py +++ b/guidance/models/llama_cpp/_llama_cpp.py @@ -88,7 +88,7 @@ def __init__(self, model_obj, chat_template=None): tokens, chat_template, tokenizer.llama.token_bos(), tokenizer.llama.token_eos() ) - def encode(self, byte_string: bytes) -> Sequence[int]: + def encode(self, byte_string: bytes) -> list[int]: # Workaround for the LlamaCpp prepending spaces on encoding raw_tokens = self._model_obj.tokenize( self._sentinel_bytes + byte_string, add_bos=False, special=True @@ -153,7 +153,7 @@ def __init__(self, model, compute_log_probs, chat_template=None, **kwargs): self._n_vocab = len(self.tokenizer.tokens) - def get_logits(self, token_ids, forced_bytes, current_temp): + def get_logits(self, token_ids): """Computes the logits for the given token state. This overrides a method from the LocalEngine class that is used to get diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py index 2be6af02b..d454bda2d 100644 --- a/guidance/models/transformers/_transformers.py +++ b/guidance/models/transformers/_transformers.py @@ -331,7 +331,7 @@ def _bytes_to_unicode(self): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) - def encode(self, byte_string: bytes) -> Sequence[int]: + def encode(self, byte_string: bytes) -> list[int]: assert isinstance(byte_string, bytes) # HF tokenizers take in strings apparently tokenization = self._orig_tokenizer(byte_string.decode(), add_special_tokens=False) @@ -341,7 +341,7 @@ def decode(self, tokens: Sequence[int]) -> bytes: decoded_str = self._orig_tokenizer.decode(tokens) return decoded_str.encode() - def recode(self, tokens: Sequence[int]) -> Sequence[int]: + def recode(self, tokens: Sequence[int]) -> list[int]: # the encode/decode cycle might not work if we have partial unicode strings used_tokens = len(tokens) for _ in range(3): @@ -414,7 +414,6 @@ def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None my_tokenizer, compute_log_probs=compute_log_probs, ) - assert self._token_trie.match def _model(self, model, **kwargs): # intantiate the model if needed @@ -428,7 +427,7 @@ def _model(self, model, **kwargs): model = transformers_package.AutoModelForCausalLM.from_pretrained(model, **kwargs) return model - def get_logits(self, token_ids, forced_bytes, current_temp): + def get_logits(self, token_ids): """Computes the logits for the given token state. This overrides a method from the LocalEngine class that is used to get diff --git a/pyproject.toml b/pyproject.toml index 42ff84ba8..513bbc5a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,7 @@ module = "vertexai.*" ignore_missing_imports = true [[tool.mypy.overrides]] -module = "google.generativeai.*" -ignore_missing_imports = true - -[[tool.mypy.overrides]] -module = "google.ai.generativelanguage.*" +module = "google.*" ignore_missing_imports = true [[tool.mypy.overrides]] diff --git a/setup.py b/setup.py index eaab77b75..62a8fa067 100644 --- a/setup.py +++ b/setup.py @@ -25,10 +25,10 @@ "numpy", "ordered_set", "platformdirs", - "protobuf", "pydantic", "requests", "tiktoken>=0.3", + "llguidance", ] # Our basic list of 'extras' @@ -60,12 +60,13 @@ "bitsandbytes", "jupyter", "papermill", + "protobuf", "pytest", "pytest-cov", + "sentencepiece", "torch", "transformers", "mypy==1.9.0", - "types-protobuf", "types-regex", "types-requests", "types-jsonschema", diff --git a/tests/model_integration/library/test_commit_point.py b/tests/model_integration/library/test_commit_point.py index 94dc91a76..d0f2a156c 100644 --- a/tests/model_integration/library/test_commit_point.py +++ b/tests/model_integration/library/test_commit_point.py @@ -1,6 +1,7 @@ +import pytest from guidance import Tool, capture, commit_point, models, select, string - +@pytest.mark.xfail(reason="Commit points are not supported") def test_commit_point(selected_model: models.Model): lm = selected_model tools = [Tool(callable=lambda x: x)] diff --git a/tests/model_integration/library/test_gen.py b/tests/model_integration/library/test_gen.py index bf37c7a2a..3a03aa91c 100644 --- a/tests/model_integration/library/test_gen.py +++ b/tests/model_integration/library/test_gen.py @@ -156,8 +156,8 @@ def test_non_token_force(selected_model: models.Model): "pattern", [ "(Scott is a person|Scott is a persimmon)", - r"Scott is a persimmon.*\.", - r"\d\.*\d+", + r"Scott is a persimmon.{0,20}\.", + r"[0-9]\.{0,20}[0-9]+", ], ) def test_various_regexes(selected_model: models.Model, prompt: str, pattern: str): diff --git a/tests/model_integration/library/test_subgrammar.py b/tests/model_integration/library/test_subgrammar.py new file mode 100644 index 000000000..6134eb8ff --- /dev/null +++ b/tests/model_integration/library/test_subgrammar.py @@ -0,0 +1,84 @@ +import re +import numpy as np +import pytest +from jsonschema import validate +import json + +import guidance +from guidance import ( + gen, + select, + optional, + one_or_more, +) +from guidance.library._subgrammar import subgrammar, lexeme + + +@guidance(stateless=True) +def json_string(lm): + return lm + lexeme(r'"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)*"') + + +@guidance(stateless=True) +def json_number(lm): + return lm + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?") + + +@guidance(stateless=True) +def json_value(lm): + return lm + select( + [ + json_string(), + json_number(), + json_object(), + json_array(), + "true", + "false", + "null", + ] + ) + + +@guidance(stateless=True) +def json_member(lm): + return lm + json_string() + ":" + json_value() + + +@guidance(stateless=True) +def json_object(lm): + return lm + "{" + optional(json_member() + one_or_more("," + json_member())) + "}" + + +@guidance(stateless=True) +def json_array(lm): + return lm + "[" + optional(json_value() + one_or_more("," + json_value())) + "]" + + +@guidance(stateless=True) +def gen_json_object(lm, name: str, max_tokens=100000000): + grm = subgrammar( + name, + body=json_object(), + skip_regex=r"[\x20\x0A\x0D\x09]+", + no_initial_skip=True, + max_tokens=max_tokens + ) + return lm + grm + + +def test_greedy_json_object(selected_model: guidance.models.Model): + lm = selected_model + lm += "John Doe's name, age, and birthday:\n" + lm += gen_json_object("hacker", max_tokens=1000) + lm += "\nScore: " + gen("score", regex="[1-3]") + # make sure it parses as JSON + obj = json.loads(lm["hacker"]) + assert isinstance(obj, dict) + assert lm["score"] in ["1", "2", "3"] + + +def test_greedy_single_terminal(selected_model: guidance.models.Model): + lm = selected_model + lm += "A number: " + lm += subgrammar(body=lexeme(r"[0-9]{3}")) + assert re.search(r": [0-9]{3}$", str(lm)) diff --git a/tests/model_integration/test_model.py b/tests/model_integration/test_model.py index 619c87e96..7f74f1664 100644 --- a/tests/model_integration/test_model.py +++ b/tests/model_integration/test_model.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import patch import guidance from guidance import byte_range, gen, models, select, zero_or_more @@ -31,9 +32,8 @@ def test_token_count(selected_model): def test_token_healing(selected_model): """Tests a bug where the space is incorrectly forced as token 220, while it should be not forced it might be extended""" - model_type = type(selected_model.engine.model_obj).__name__ - print(model_type) - if model_type != "GPT2LMHeadModel": + model_obj = getattr(selected_model.engine, "model_obj", None) + if model_obj is None or type(model_obj).__name__ != "GPT2LMHeadModel": pytest.skip("Test for GPT2 bug only") gpt2 = selected_model lm = gpt2 + ( @@ -68,3 +68,32 @@ def test_stream_add_multiple(selected_model): lm += "" *_, last_lm = lm assert str(last_lm) in ["item1", "item2"] + + +def test_associativity(selected_model): + prompt = "pi = " + grammar = gen("number", regex=r"\d") + engine = selected_model.engine + + with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_1: + _ = selected_model + (prompt + grammar) + prompt_tokens_1 = get_next_token_1.call_args.kwargs["token_ids"] + + with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_2: + _ = (selected_model + prompt) + grammar + prompt_tokens_2 = get_next_token_2.call_args.kwargs["token_ids"] + + # Main assertion: the prompt tokens should be the same + assert prompt_tokens_1 == prompt_tokens_2 + + # Further assert that the tokenization matches the expected tokenization + expected_prompt_tokens = engine.tokenizer.encode(prompt.encode()) + if ( + engine.tokenizer.bos_token is not None + and expected_prompt_tokens[:1] != [engine.tokenizer.bos_token_id] + ): + expected_prompt_tokens = [engine.tokenizer.bos_token_id] + expected_prompt_tokens + expected_prompt_tokens = engine.tokenizer.recode(expected_prompt_tokens) + # token healing may cause the prompt seen by the model to be shorter + assert len(expected_prompt_tokens) >= len(prompt_tokens_1) + assert prompt_tokens_1 == expected_prompt_tokens[:len(prompt_tokens_1)] diff --git a/tests/model_specific/test_transformers.py b/tests/model_specific/test_transformers.py index 8774a8ca6..2b744e925 100644 --- a/tests/model_specific/test_transformers.py +++ b/tests/model_specific/test_transformers.py @@ -50,14 +50,15 @@ def test_recursion_error(): @pytest.mark.parametrize(["model_name", "model_kwargs"], TRANSFORMER_MODELS.items()) def test_transformer_smoke_gen(model_name, model_kwargs): + MAX_TOKENS = 2 my_model = get_model(f"transformers:{model_name}", **model_kwargs) prompt = "How many sides has a triangle?" - lm = my_model + prompt + gen(name="answer", max_tokens=2) + lm = my_model + prompt + gen(name="answer", max_tokens=MAX_TOKENS) assert len(lm["answer"]) > 0, f"Output: {lm['answer']}" - # Inexact, but at least make sure not too much was produced - assert len(lm["answer"]) < 8, f"Output: {lm['answer']}" + # Make sure not too much was produced + assert len(lm.engine.tokenizer.encode(lm["answer"].encode())) <= MAX_TOKENS, f"Output: {lm['answer']}" @pytest.mark.parametrize(["model_name", "model_kwargs"], TRANSFORMER_MODELS.items()) diff --git a/tests/need_credentials/test_azure_guidance.py b/tests/need_credentials/test_azure_guidance.py index cadfaab3d..b1b12f8bc 100644 --- a/tests/need_credentials/test_azure_guidance.py +++ b/tests/need_credentials/test_azure_guidance.py @@ -1,12 +1,15 @@ +import re import numpy as np import pytest +from jsonschema import validate +import json import guidance -from guidance import gen, select, assistant, user +from guidance import gen, select, assistant, user, optional, substring, one_or_more, token_limit +from guidance.library import json as gen_json from ..utils import get_model - @pytest.fixture(scope="module") def azure_guidance_model(selected_model, selected_model_name): if selected_model_name in ["azure_guidance"]: @@ -15,7 +18,94 @@ def azure_guidance_model(selected_model, selected_model_name): pytest.skip("Requires Azure Guidance model") -def test_azure_guidance_gen(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_fill_in_json(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + @guidance(stateless=True, dedent=False) + def character_maker(lm, id, description, valid_weapons): + lm += f"""\ + The following is a short character profile for an RPG game in JSON format. + ```json + {{ + "id": "{id}", + "description": "{description}", + "name": "{gen('name', stop='"')}", + "age": {gen('age', regex='[0-9]+', stop=',')}, + "armor": "{select(options=['leather', 'chainmail', 'plate'], name='armor')}", + "weapon": "{select(options=valid_weapons, name='weapon')}", + "class": "{gen('class', stop='"')}", + "mantra": "{gen('mantra', stop='"')}", + "strength": {gen('strength', regex='[0-9]+', stop=',')}, + "items": ["{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}"] + }}```""" + return lm + lm += character_maker(1, 'A nimble fighter', ['axe', 'sword', 'bow']) + result = str(lm) + json_text = result[result.find("```json") + 8:-3] + json.loads(json_text) # check for valid JSON + + +def test_azure_guidance_basic_1(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Write a number: " + gen("text", max_tokens=3) + assert len(lm["text"]) >= 3 + +def test_azure_guidance_56_eos(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + # make sure we recognize EOS token correctly + lm += "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + assert lm["text"] == "56" + +def test_azure_guidance_56_newline(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + # make sure we recognize EOS token correctly + lm += "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n" + assert lm["text"] == "56" + +def test_azure_guidance_1003_eos(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Q: 1000 + 3\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + assert lm["text"] == "1003" + +def test_azure_guidance_dolphins(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + # Yes|No has an implicit forced EoS at the end, which should not be actually generated + lm += "Q: Are dolphins fish?\nA: " + gen("dolphins", regex="Yes|No", max_tokens=10) + \ + "\nQ: Are salmons fish?\nA: " + gen("sharks", regex="Yes|No", max_tokens=10) + assert lm["dolphins"] == "No" + assert lm["sharks"] == "Yes" + +def test_azure_guidance_1003_max_tokens(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Q: 1000 + 3\nA: " + gen("text", regex="[0-9]+", max_tokens=2) + assert lm["text"] == "10" + + +def test_azure_guidance_max_tokens_1(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "one, two, three, " + gen(name="a", max_tokens=1) + gen(name="b", max_tokens=1) + assert lm["a"] == "four" and lm["b"] == "," + +def test_azure_guidance_max_tokens_2(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "one, two, three, " + gen(name="a", max_tokens=2) + gen(name="b", max_tokens=2) + assert lm["a"] == "four," and lm["b"] == " five," + + +def test_azure_guidance_stop_char(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen("text", stop=",") + assert lm["text"] == "8" + + +def test_azure_guidance_stop_string(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen("text", stop=", 9") + print(str(lm)) + assert lm["text"] == "8" + + + +def test_azure_guidance_gen_base(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm = lm + "this is a test" + gen("test", max_tokens=10) assert len(str(lm)) > len("this is a test") @@ -58,7 +148,7 @@ def test_azure_guidance_repeat_calls(azure_guidance_model: guidance.models.Model a = [] lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10) a.append(lm["test"]) - lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\d+") + lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"[0-9]+") a.append(lm["test"]) lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10) a.append(lm["test"]) @@ -72,17 +162,23 @@ def test_azure_guidance_suffix(azure_guidance_model: guidance.models.Model): + "1. Here is a sentence " + gen(name="bla", list_append=True, suffix="\n") ) + # list_append + assert isinstance(lm["bla"], list) + assert len(lm["bla"]) == 1 + # the capture should not have a newline + assert lm["bla"][0][-1] != "\n" + # the whole lm object *should* have a newline assert (str(lm))[-1] == "\n" assert (str(lm))[-2] != "\n" -def test_azure_guidance_subtoken_forced(azure_guidance_model: guidance.models.Model): - lm_orig = azure_guidance_model - lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(") - assert str(lm) == "How much is 2 + 2? (" +# def test_azure_guidance_subtoken_forced(azure_guidance_model: guidance.models.Model): +# lm_orig = azure_guidance_model +# lm = lm_orig + "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(") +# assert str(lm) == "How much is 2 + 2? (" -def test_azure_guidance_with_temp(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_with_temp1(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm += "Here is a cute 5-line poem about cats and dogs:\n" for i in range(5): @@ -97,7 +193,7 @@ def test_azure_guidance_with_temp2(azure_guidance_model: guidance.models.Model): assert lm1["answer"] == lm2["answer"] -def test_azure_guidance_max_tokens(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_max_tokens_3(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm += "Who won the last Kentucky derby and by how much?" lm += "\n\n<red\n{gen(stop="")} and test2' r = str(lm) @@ -116,20 +212,20 @@ def test_azure_guidance_stop_token(azure_guidance_model: guidance.models.Model): assert "" not in r[20:] assert " and test2" in r[20:] -def test_azure_guidance_basic(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_basic_2(azure_guidance_model: guidance.models.Model): model = azure_guidance_model lm = model + "Count to 20: 1,2,3,4," nl = "\n" lm += "5,6,7" + f"""{gen(max_tokens=1, suffix=nl)}aaaaaa""" assert str(lm)[-6:] == "aaaaaa" -def test_azure_guidance_fstring(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_fstring_simple(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm += f'this is a test {select(["item1", "item2"])}' assert str(lm) in ["this is a test item1", "this is a test item2"] -def test_azure_guidance_fstring_custom(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_fstring_custom_statefull(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model @guidance @@ -140,6 +236,17 @@ def my_function(lm): assert str(lm) in ["this is a test another item1", "this is a test another item2"] +def test_azure_guidance_fstring_custom_stateless(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + + @guidance(stateless=True) + def my_function(lm): + return lm + f'another {select(["item1", "item2"])}' + + lm += f"this is a test {my_function()}" + assert str(lm) in ["this is a test another item1", "this is a test another item2"] + + def test_azure_guidance_token_count(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm2 = lm + " 1 1 1 1 1" + gen(max_tokens=9) + gen(max_tokens=9) @@ -168,7 +275,7 @@ def ble(lm): assert "{{G|" not in str(model + ble()) -def test_azure_guidance_stream(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_stream_0(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model lm = lm + select(["item1", "item2"]) assert str(lm) in ["item1", "item2"] @@ -194,7 +301,7 @@ def test_azure_guidance_stream_add_multiple(azure_guidance_model: guidance.model assert str(lm) in ["item1", "item2"] -def test_azure_guidance(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_1_plus_1(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model with user(): lm += "What is 1 + 1?" @@ -204,7 +311,7 @@ def test_azure_guidance(azure_guidance_model: guidance.models.Model): assert len(lm["text"]) > 0 -def test_azure_guidance_select(azure_guidance_model: guidance.models.Model): +def test_azure_guidance_select1(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model with user(): lm += "Pick a number: " @@ -230,18 +337,19 @@ def test_azure_guidance_loop(azure_guidance_model: guidance.models.Model): def test_azure_guidance_chat(azure_guidance_model: guidance.models.Model): lm = azure_guidance_model + max_tokens=30 with user(): lm += "The economy is crashing!" with assistant(): - lm += gen("test1", max_tokens=100) + lm += gen("test1", max_tokens=max_tokens) with user(): lm += "What is the best again?" with assistant(): - lm += gen("test2", max_tokens=100) + lm += gen("test2", max_tokens=max_tokens) assert len(lm["test1"]) > 0 assert len(lm["test2"]) > 0 @@ -253,13 +361,259 @@ def test_azure_guidance_chat(azure_guidance_model: guidance.models.Model): lm += "The economy is crashing!" with assistant(): - lm += gen("test1", max_tokens=100) + lm += gen("test1", max_tokens=max_tokens) with user(): lm += "What is the best again?" with assistant(): - lm += gen("test2", max_tokens=100) + lm += gen("test2", max_tokens=max_tokens) assert len(lm["test1"]) > 0 assert len(lm["test2"]) > 0 + +def test_azure_guidance_phi3_newline_chat(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "You are a counting bot. Just keep counting numbers." + with user(): + lm += "1\n2\n3\n4\n" + with assistant(): + lm += "\n" + gen(name="five", max_tokens=1) + lm += "\n" + gen(name="six", max_tokens=1) + +def test_azure_guidance_phi3_unstable_tokenization(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "You are a counting bot. Just keep counting numbers." + with user(): + lm += "1,2,3,4," + with assistant(): + lm += "\n" # comment and uncomment this line to get the error + lm += gen(name="five", max_tokens=1) + lm += "," + gen(name="six", max_tokens=1) + + +def test_azure_guidance_simple_recursion(azure_guidance_model: guidance.models.Model): + @guidance(stateless=True, dedent=False) + def grammar(lm, depth): + if depth != 0: + depth -= 1 + lm += "x" + optional(grammar(depth)) + return lm + lm = azure_guidance_model + lm += grammar(5) + + +def test_azure_guidance_mutual_recursion(azure_guidance_model: guidance.models.Model): + @guidance(stateless=True, dedent=False) + def grammar1(lm, depth): + if depth != 0: + depth -= 1 + lm += "x" + grammar2(depth) + return lm + + @guidance(stateless=True, dedent=False) + def grammar2(lm, depth): + if depth != 0: + depth -= 1 + lm += "y" + optional(grammar1(depth)) + return lm + + lm = azure_guidance_model + lm += grammar1(5) + lm += grammar2(5) + +def test_azure_guidance_multiple_mutual_recursion(azure_guidance_model: guidance.models.Model): + @guidance(stateless=True, dedent=False) + def grammar1(lm, depth): + if depth != 0: + depth -= 1 + lm += "x" + grammar2(depth) + return lm + + @guidance(stateless=True, dedent=False) + def grammar2(lm, depth): + if depth != 0: + depth -= 1 + lm += "y" + grammar3(depth) + return lm + + @guidance(stateless=True, dedent=False) + def grammar3(lm, depth): + if depth != 0: + depth -= 1 + lm += "z" + optional(grammar1(depth)) + return lm + + lm = azure_guidance_model + lm += grammar1(5) + lm += grammar2(5) + lm += grammar3(5) + +def test_azure_guidance_branching_mutual_recursion(azure_guidance_model: guidance.models.Model): + @guidance(stateless=True, dedent=False) + def grammar1(lm, depth): + if depth != 0: + depth -= 1 + lm += "x" + grammar2(depth) + return lm + + @guidance(stateless=True, dedent=False) + def grammar2(lm, depth): + if depth != 0: + depth -= 1 + lm += "y" + select([grammar1(depth), grammar3(depth)]) + return lm + + @guidance(stateless=True, dedent=False) + def grammar3(lm, depth): + if depth != 0: + depth -= 1 + lm += "z" + optional(grammar1(depth)) + return lm + + lm = azure_guidance_model + lm += grammar1(5) + lm += grammar2(5) + lm += grammar3(5) + + +# def test_remote_gen_json(azure_guidance_model: guidance.models.Model): +# schema = """ +# { +# "$defs": { +# "A": { +# "properties": { +# "my_str": { +# "default": "me", +# "title": "My Str", +# "type": "string" +# }, +# "next": { +# "anyOf": [ +# { +# "$ref": "#/$defs/A" +# }, +# { +# "type": "null" +# } +# ] +# } +# }, +# "type": "object" +# } +# }, +# "type": "object", +# "properties": { +# "my_list": { +# "anyOf": [ +# { +# "$ref": "#/$defs/A" +# }, +# { +# "type": "null" +# } +# ] +# } +# } +# } +# """ +# schema_obj = json.loads(schema) + +# m = azure_guidance_model +# m += gen_json(schema=schema_obj, name="my_json_string") +# print(f"Raw: {m['my_json_string']}") + +# my_obj = json.loads(m["my_json_string"]) +# print(f"Received object: {json.dumps(my_obj, indent=4)}") +# validate(my_obj, schema_obj) + + +# @pytest.mark.parametrize( +# "test_str", +# [ +# "is this legal", +# "I'm not sure ias;ldlkas is the best", +# "\n\nit works\n\n", +# "0123456789", +# ], +# ) +# def test_mocked_substring(test_str, azure_guidance_model: guidance.models.Model): +# m = azure_guidance_model + +# lm = m + substring(test_str, name="result") +# print(f'Substring: \'{lm["result"]}\' :::: \'{test_str}\'') +# assert lm["result"] in test_str + + +def test_azure_guidance_stateless_inside_stateful(azure_guidance_model: guidance.models.Model): + @guidance(stateless=False, dedent=False) + def stateful_grammar1(lm): + return lm + select(["+", "-"]) + stateful_grammar2() + + @guidance(stateless=False, dedent=False) + def stateful_grammar2(lm): + return lm + "p4" + stateless_grammar1() + + @guidance(stateless=True, dedent=False) + def stateless_grammar1(lm): + return lm + "3L" + stateless_grammar2() + + @guidance(stateless=True, dedent=False) + def stateless_grammar2(lm): + return lm + "Yy" + stateless_grammar3() + + @guidance(stateless=True, dedent=False) + def stateless_grammar3(lm): + return lm + select(["A", "B"]) + + lm = azure_guidance_model + lm += "begin:" + stateful_grammar1() + result = str(lm) + assert result == "begin:+p43LYyA" or result == "begin:-p43LYyA" or result == "begin:+p43LYyB" or result == "begin:-p43LYyB" + + +def test_azure_guidance_string(azure_guidance_model: guidance.models.Model): + model = azure_guidance_model + # limit number of tokens, otherwise test is very slow + s = str(model + "ab" + token_limit(one_or_more("ab"), 30)) + assert len(s) >= 4 + assert bool(re.fullmatch(r'(ab)*', s)) or bool(re.fullmatch(r'(ab)*', s[:-1])) + + +def test_azure_guidance_stop_token_name(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Name: " + gen('name', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop") + assert lm["saved_name_stop"] in ["a", "b", "x", "y", "z"] + assert lm["name"].startswith("E") + +def test_azure_guidance_stop_token_name2(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + # repeat the token to get duplicated lexeme + lm += "Name: " + gen('name', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop") + \ + "\nName: " + gen('name2', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text=True) + assert lm["saved_name_stop"] in ["a", "b", "x", "y", "z"] + assert lm["name"].startswith("E") + assert lm["name2_stop_text"] in ["a", "b", "x", "y", "z"] + assert lm["name2"].startswith("E") + +def test_azure_guidance_max_tokens_4(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + lm += "Name: " + gen('name', max_tokens=5) + " and " + gen('name2', max_tokens=5) + assert len(lm["name"]) > 0 + assert len(lm["name2"]) > 0 + +def test_azure_guidance_zero_temperature(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + responses = [] + for _ in range(10): + temp = lm + "Number: " + gen("number", regex=r"\d", temperature=0.0) + responses.append(temp["number"]) + assert len(set(responses)) == 1 + +def test_azure_guidance_high_temperature(azure_guidance_model: guidance.models.Model): + lm = azure_guidance_model + responses = [] + for _ in range(10): + temp = lm + "Number: " + gen("number", regex=r"\d", temperature=0.9) + responses.append(temp["number"]) + assert len(set(responses)) > 1 \ No newline at end of file diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 7bcda4a12..b269da609 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -21,10 +21,10 @@ def server_process(*, mock_string: Union[str, List[str]] = ""): byte_patterns = [] if isinstance(mock_string, str): - byte_patterns = [f"{mock_string}".encode()] + byte_patterns = [f"{mock_string}".encode()] else: for s in mock_string: - byte_patterns.append(f"{s}".encode()) + byte_patterns.append(f"{s}".encode()) lm = models.Mock(byte_patterns=byte_patterns) temp_lm = lm + gen() @@ -72,7 +72,7 @@ def test_return_mock_string(): my_string = "My roundtrip" with ServerContext(mock_string=my_string): m = models.Model("http://localhost:8392", api_key="SDFSDF") - m2 = m + gen(max_tokens=10, name="captured") + m2 = m + gen(max_tokens=20, name="captured") assert m2["captured"].startswith(my_string) diff --git a/tests/unit/library/test_commit_point.py b/tests/unit/library/test_commit_point.py index e5f6e1678..58b4639b5 100644 --- a/tests/unit/library/test_commit_point.py +++ b/tests/unit/library/test_commit_point.py @@ -1,6 +1,7 @@ -from guidance import commit_point, models - +import pytest +from guidance import commit_point, one_or_more, byte_range, models +@pytest.mark.xfail(reason="Commit points are not supported") def test_hidden(): model = models.Mock() model += " one" + commit_point(" two", hidden=True) + " three" diff --git a/tests/unit/library/test_gen.py b/tests/unit/library/test_gen.py index 98af94de2..91fdec7b3 100644 --- a/tests/unit/library/test_gen.py +++ b/tests/unit/library/test_gen.py @@ -1,3 +1,6 @@ +import pytest +from collections import defaultdict +import guidance from guidance import gen, models @@ -111,6 +114,14 @@ def test_list_append(): assert isinstance(lm["my_list"], list) assert len(lm["my_list"]) == 3 +@pytest.mark.xfail( + reason="llguidance currently emits an additional empty capture group when no explicit stop is provided" +) +def test_list_append_no_explicit_stop(): + model = models.Mock("bbbbbbb") + model += gen("list", list_append=True) + assert model["list"][-1] == "bbbbbbb" + assert len(model["list"]) == 1 def test_list_append_in_grammar(): """This tests is list append works within the same grammar.""" @@ -137,3 +148,86 @@ def test_one_char_stop_and_regex(): model = models.Mock(b"this is\na test") model += gen(regex=".*", stop="\n", max_tokens=20) assert str(model) == "this is" + + +def test_multiline(): + model = models.Mock(b"this\nis\na\ntest") + model += gen(max_tokens=20) + assert str(model) == "this\nis\na\ntest" + + +def test_tool_call(): + called_args = [] + + @guidance(dedent=False) + def square(lm, x): + called_args.append(x) + return lm + str(int(x)**2) + + model = models.Mock(b"Three squared is square(3)9") + model += gen(tools=[square], max_tokens=30) + assert str(model) == "Three squared is square(3)9" + assert called_args == ["3"] + + +def test_tool_call_hidden(): + called_args = [] + + @guidance(dedent=False) + def square(lm, x): + called_args.append(x) + return lm + str(int(x)**2) + + model = models.Mock([ + b"Three squared is square(3)", + b"Three squared is 9" + ]) + model += gen(tools=[square], hide_tool_call=True, max_tokens=30) + assert str(model) == "Three squared is 9" + assert called_args == ["3"] + + +def test_tool_call_multi(): + called_args = defaultdict(list) + + @guidance(dedent=False) + def square(lm, x): + called_args['square'].append(x) + return lm + str(int(x)**2) + + @guidance(dedent=False) + def cube(lm, x): + called_args['cube'].append(x) + return lm + str(int(x)**3) + + model = models.Mock( + b"Three squared is square(3)9, which cubed is cube(9)729. Good job me.", + ) + model += gen(tools=[square, cube], hide_tool_call=False, max_tokens=50) + assert str(model) == "Three squared is square(3)9, which cubed is cube(9)729. Good job me." + assert called_args["square"] == ["3"] + assert called_args["cube"] == ["9"] + + +def test_tool_call_multi_hidden(): + called_args = defaultdict(list) + + @guidance(dedent=False) + def square(lm, x): + called_args['square'].append(x) + return lm + str(int(x)**2) + + @guidance(dedent=False) + def cube(lm, x): + called_args['cube'].append(x) + return lm + str(int(x)**3) + + model = models.Mock([ + b"Three squared is square(3)", + b"Three squared is 9, which cubed is cube(9)", + b"Three squared is 9, which cubed is 729. Good job me." + ]) + model += gen(tools=[square, cube], hide_tool_call=True, max_tokens=50) + assert str(model) == "Three squared is 9, which cubed is 729. Good job me." + assert called_args["square"] == ["3"] + assert called_args["cube"] == ["9"] diff --git a/tests/unit/library/test_json.py b/tests/unit/library/test_json.py index 4669a1071..bd5dcf76c 100644 --- a/tests/unit/library/test_json.py +++ b/tests/unit/library/test_json.py @@ -1,14 +1,14 @@ import json from functools import partial -from typing import Any, Dict, Set, Union +from typing import Any, Dict, Set, Union, Optional import pytest from jsonschema import validate from guidance import json as gen_json from guidance import models -from guidance._grammar import Byte, ByteRange, byte_range -from guidance.library._json import _to_compact_json + +from guidance.library._json import _to_compact_json, WHITESPACE from ...utils import check_match_failure as _check_match_failure from ...utils import check_run_with_temperature @@ -16,7 +16,7 @@ def generate_and_check( - target_obj: Any, schema_obj, desired_temperature: Union[float, None] = None + target_obj: Any, schema_obj, desired_temperature: Optional[float] = None ): # Sanity check what we're being asked validate(instance=target_obj, schema=schema_obj) @@ -26,7 +26,9 @@ def generate_and_check( # Now test that the grammar can recognize and generate prepared_json # We partial in the grammar_callable if desired_temperature is not None: - grammar_callable = partial(gen_json, schema=schema_obj, temperature=desired_temperature) + grammar_callable = partial( + gen_json, schema=schema_obj, temperature=desired_temperature + ) else: grammar_callable = partial(gen_json, schema=schema_obj) @@ -41,23 +43,28 @@ def check_match_failure( bad_string: str, good_bytes: bytes, failure_byte: bytes, - allowed_bytes: Union[Set[Union[Byte, ByteRange]], None], + allowed_bytes: Optional[Set[bytes]], schema_obj: Dict[str, Any], + maybe_whitespace: bool, + compact: bool, ): - grammar = gen_json(schema=schema_obj) + grammar = gen_json(schema=schema_obj, compact=compact) _check_match_failure( bad_string=bad_string, good_bytes=good_bytes, failure_byte=failure_byte, - allowed_bytes=allowed_bytes, + allowed_bytes=( + allowed_bytes.union(WHITESPACE) if (maybe_whitespace and not compact and allowed_bytes is not None) + else allowed_bytes + ), grammar=grammar, ) # Common sets of allowed_bytes -INTEGER_LEADING = {Byte(b"-"), Byte(b"0"), ByteRange(b"19")} -INTEGER_FOLLOWING = {ByteRange(b"09")} - +INTEGER_LEADING = {b"-", b"0", *{bytes([i]) for i in range(ord("1"), ord("9") + 1)}} +INTEGER_FOLLOWING = {bytes([i]) for i in range(ord("0"), ord("9") + 1)} +A_to_Z = {bytes([i]) for i in range(ord("A"), ord("Z") + 1)} def test_null(): schema = """{"type": "null" }""" @@ -99,17 +106,20 @@ def test_integer_schema(self, my_int): generate_and_check(my_int, schema_obj) @pytest.mark.parametrize( - ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], + "compact", [True, False] + ) + @pytest.mark.parametrize( + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes", "maybe_whitespace"], [ - ("9999a7777", b"9999", b"a", INTEGER_FOLLOWING), - ("123, []", b"123", b",", INTEGER_FOLLOWING), - ("a321", b"", b"a", INTEGER_LEADING), - ("123789.456", b"123789", b".", INTEGER_FOLLOWING), - ("[]", b"", b"[", INTEGER_LEADING), - ('{"a":4}', b"", b"{", INTEGER_LEADING), + ("9999a7777", b"9999", b"a", INTEGER_FOLLOWING, True), + ("123, []", b"123", b",", INTEGER_FOLLOWING, True), + ("a321", b"", b"a", INTEGER_LEADING, False), + ("123789.456", b"123789", b".", INTEGER_FOLLOWING, True), + ("[]", b"", b"[", INTEGER_LEADING, False), + ('{"a":4}', b"", b"{", INTEGER_LEADING, False), ], ) - def test_bad_integer(self, bad_string, good_bytes, failure_byte, allowed_bytes): + def test_bad_integer(self, bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(TestInteger.schema) check_match_failure( bad_string=bad_string, @@ -117,6 +127,8 @@ def test_bad_integer(self, bad_string, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -154,16 +166,19 @@ def test_number(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], + "compact", [True, False] + ) + @pytest.mark.parametrize( + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes", "maybe_whitespace"], [ - ("9999a7777", b"9999", b"a", {Byte(b"e"), Byte(b"."), *INTEGER_FOLLOWING}), - ("123.6, []", b"123.6", b",", {Byte(b"e"), *INTEGER_FOLLOWING}), - ("a321", b"", b"a", INTEGER_LEADING), - ("[]", b"", b"[", INTEGER_LEADING), - ('{"a":4}', b"", b"{", INTEGER_LEADING), + ("9999a7777", b"9999", b"a", {b"e", b"E", b".", *INTEGER_FOLLOWING}, True), + ("123.6, []", b"123.6", b",", {b"e", b"E", *INTEGER_FOLLOWING}, True), + ("a321", b"", b"a", INTEGER_LEADING, False), + ("[]", b"", b"[", INTEGER_LEADING, False), + ('{"a":4}', b"", b"{", INTEGER_LEADING, False), ], ) - def test_bad_number(self, bad_string, good_bytes, failure_byte, allowed_bytes): + def test_bad_number(self, bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(TestNumber.schema) check_match_failure( bad_string=bad_string, @@ -171,6 +186,8 @@ def test_bad_number(self, bad_string, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -232,8 +249,8 @@ def test_regex_no_min_max_length(self): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ('"ab"', b'"a', b"b", set([byte_range(b"A", b"Z")])), - ('"a1"', b'"a', b"1", set([byte_range(b"A", b"Z")])), + ('"ab"', b'"a', b"b", A_to_Z), + ('"a1"', b'"a', b"1", A_to_Z), ], ) def test_regex_bad(self, bad_string: str, good_bytes, failure_byte, allowed_bytes): @@ -247,6 +264,8 @@ def test_regex_bad(self, bad_string: str, good_bytes, failure_byte, allowed_byte failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=False, + compact=True, ) @pytest.mark.parametrize( @@ -266,7 +285,7 @@ def test_min_and_maxLength(self, my_string: str): ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ ('""', b'"', b'"', None), - ('"dddd"', b'"ddd', b"d", set([Byte(b'"')])), + ('"dddd"', b'"ddd', b"d", {b'"'}), ], ) def test_min_and_maxLength_bad(self, bad_string: str, good_bytes, failure_byte, allowed_bytes): @@ -280,6 +299,8 @@ def test_min_and_maxLength_bad(self, bad_string: str, good_bytes, failure_byte, failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=False, + compact=True, ) @pytest.mark.parametrize( @@ -351,6 +372,8 @@ def test_minLength_bad(self, bad_string: str, good_bytes, failure_byte, allowed_ failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=False, + compact=True, ) @pytest.mark.parametrize( @@ -392,8 +415,8 @@ def test_maxLength_zero(self): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ('"aaa"', b'"aa', b"a", set([Byte(b'"')])), - ('"1111"', b'"11', b"1", set([Byte(b'"')])), + ('"aaa"', b'"aa', b"a", {b'"'}), + ('"1111"', b'"11', b"1", {b'"'}), ], ) def test_maxLength_bad(self, bad_string: str, good_bytes, failure_byte, allowed_bytes): @@ -407,6 +430,8 @@ def test_maxLength_bad(self, bad_string: str, good_bytes, failure_byte, allowed_ failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=False, + compact=True, ) @@ -511,14 +536,17 @@ def test_object_containing_list(self, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], + "compact", [True, False] + ) + @pytest.mark.parametrize( + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes", "maybe_whitespace"], [ - ("9999a7777", b"", b"9", {Byte(b"{")}), - ('{"a":1255.4567}', b'{"a":1255', b".", {Byte(b"}"), *INTEGER_FOLLOWING}), - ('{"a":"123"}', b'{"a":', b'"', INTEGER_LEADING), + ("9999a7777", b"", b"9", {b"{"}, False), + ('{"a":1255.4567}', b'{"a":1255', b".", {b"}", *INTEGER_FOLLOWING}, True), + ('{"a":"123"}', b'{"a":', b'"', INTEGER_LEADING, True), ], ) - def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): + def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema = """{ "type": "object", "properties": { @@ -534,6 +562,8 @@ def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -601,14 +631,17 @@ def test_object_list(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], + "compact", [True, False] + ) + @pytest.mark.parametrize( + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes", "maybe_whitespace"], [ - ("9999a7777", b"", b"9", {Byte(b"[")}), - ("[321.654]", b"[321", b".", {Byte(b"]"), Byte(b","), *INTEGER_FOLLOWING}), - ('["123"]', b"[", b'"', {Byte(b"]"), *INTEGER_LEADING}), + ("9999a7777", b"", b"9", {b"["}, False), + ("[321.654]", b"[321", b".", {b"]", b",", *INTEGER_FOLLOWING}, True), + ('["123"]', b"[", b'"', {b"]", *INTEGER_LEADING}, True), ], ) - def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): + def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema = """{ "type" : "array", "items" : { @@ -622,6 +655,8 @@ def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -710,7 +745,10 @@ def test_good_with_items(self, min_items, max_items, target_obj): generate_and_check(target_obj, schema_obj) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ ( 1, @@ -718,7 +756,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, "string_not_bool", "hello", "extra"], b"[42,", b'"', - {Byte(b"t"), Byte(b"f")}, + {b"t", b"f"}, + True, ), # Second item does not match prefix schema ( 0, @@ -726,7 +765,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, True, 100], b"[42,true,", b"1", - {Byte(b'"')}, + {b'"'}, + True, ), # Last item does not match general item schema ( 3, @@ -734,7 +774,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, True, "valid", "extra1", "extra2", "too_many"], b'[42,true,"valid","extra1","extra2"', b",", - {Byte(b"]")}, + {b"]"}, + True, ), # Exceeds maxItems ( 2, @@ -742,7 +783,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42], b"[42", b"]", - {Byte(b","), *INTEGER_FOLLOWING}, + {b",", *INTEGER_FOLLOWING}, + True, ), # Not enough items ( 1, @@ -750,7 +792,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, True], b"[42", b",", - {Byte(b"]"), *INTEGER_FOLLOWING}, + {b"]", *INTEGER_FOLLOWING}, + True, ), # Too many items for maxItems ( 0, @@ -758,7 +801,8 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, True, "str"], b"[", b"4", - {Byte(b"]")}, + {b"]"}, + True, ), # maxItems set to 0, but array is not empty ( 3, @@ -766,12 +810,13 @@ def test_good_with_items(self, min_items, max_items, target_obj): [42, True], b"[42,true", b"]", - {Byte(b",")}, + {b","}, + True, ), # Array has one fewer item than required by minItems ], ) def test_bad_with_prefix_and_items( - self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): schema_obj = { "prefixItems": self.prefix_schema_obj, @@ -787,10 +832,15 @@ def test_bad_with_prefix_and_items( failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ ( 2, @@ -798,7 +848,8 @@ def test_bad_with_prefix_and_items( [42], b"[42", b"]", - {Byte(b","), *INTEGER_FOLLOWING}, + {b",", *INTEGER_FOLLOWING}, + True, ), # Array too short to meet minItems, despite matching prefixItems ( 1, @@ -806,7 +857,8 @@ def test_bad_with_prefix_and_items( [42, "not_bool"], b"[42,", b'"', - {Byte(b"t"), Byte(b"f")}, + {b"t", b"f"}, + True, ), # Second item violates prefixItems type requirement ( 0, @@ -814,7 +866,8 @@ def test_bad_with_prefix_and_items( [42, True], b"[42", b",", - {Byte(b"]"), *INTEGER_FOLLOWING}, + {b"]", *INTEGER_FOLLOWING}, + True, ), # Array exceeds maxItems with valid prefixItems types ( 1, @@ -822,7 +875,8 @@ def test_bad_with_prefix_and_items( [42, True, "extra"], b"[42,true", b",", - {Byte(b"]")}, + {b"]"}, + True, ), # Item beyond prefixItems with no "items" schema ( 0, @@ -830,12 +884,13 @@ def test_bad_with_prefix_and_items( [42], b"[", b"4", - {Byte(b"]")}, + {b"]"}, + True, ), # maxItems set to 0, but array is not empty ], ) def test_bad_with_prefix( - self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): schema_obj = { "prefixItems": self.prefix_schema_obj, @@ -851,10 +906,15 @@ def test_bad_with_prefix( failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ ( 1, @@ -862,7 +922,8 @@ def test_bad_with_prefix( ["hello", "world", "extra"], b'["hello","world"', b",", - {Byte(b"]")}, + {b"]"}, + True, ), # Too many items for maxItems ( 2, @@ -870,7 +931,8 @@ def test_bad_with_prefix( ["hello"], b'["hello"', b"]", - {Byte(b",")}, + {b","}, + True, ), # Not enough items ( 2, @@ -878,7 +940,8 @@ def test_bad_with_prefix( ["hello", 42], b'["hello",', b"4", - {Byte(b'"')}, + {b'"'}, + True, ), # Badly typed second item ( 0, @@ -886,12 +949,13 @@ def test_bad_with_prefix( ["hello"], b"[", b'"', - {Byte(b"]")}, + {b"]"}, + True, ), # maxItems set to 0, but array is not empty ], ) def test_bad_with_items( - self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): schema_obj = { "items": self.items_schema_obj, @@ -906,6 +970,8 @@ def test_bad_with_items( failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -1275,14 +1341,17 @@ def test_enum(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ("1", b'"', b"1", {Byte(b"2")}), - (2, b"", b"2", {Byte(b'"'), Byte(b"1"), Byte(b"f")}), - (True, b"", b"t", {Byte(b'"'), Byte(b"1"), Byte(b"f")}), + ("1", b'"', b"1", {b"2"}, False), + (2, b"", b"2", {b'"', b"1", b"f"}, False), + (True, b"", b"t", {b'"', b"1", b"f"}, False), ], ) - def test_bad_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_bad_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(self.simple_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1291,17 +1360,22 @@ def test_bad_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ("ab", b'"a', b"b", {Byte(b"a")}), - ("bc", b'"b', b"c", {Byte(b"b")}), - ("ca", b'"c', b"a", {Byte(b"c")}), + ("ab", b'"a', b"b", {b"a"}, False), + ("bc", b'"b', b"c", {b"b"}, False), + ("ca", b'"c', b"a", {b"c"}, False), ], ) - def test_bad_prefix_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_bad_prefix_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(self.prefix_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1310,6 +1384,8 @@ def test_bad_prefix_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes) failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -1367,8 +1443,10 @@ def test_constant_precedence(self): bad_string=bad_string, good_bytes=b"", failure_byte=b"2", - allowed_bytes={Byte(b"1")}, + allowed_bytes={b"1"}, schema_obj=schema_obj, + maybe_whitespace=False, + compact=True, ) @@ -1415,18 +1493,22 @@ def test_simple_additional_properties(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ({"a": "1"}, b'{"a":', b'"', INTEGER_LEADING), + ({"a": "1"}, b'{"a":', b'"', INTEGER_LEADING, True), ( {"a": 1, "b": 1.5}, b'{"a":1,"b":1', b".", - {Byte(b","), Byte(b"}"), *INTEGER_FOLLOWING}, + {b",", b"}", *INTEGER_FOLLOWING}, + True, ), ], ) - def test_simple_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_simple_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(self.simple_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1435,9 +1517,13 @@ def test_simple_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes) failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) - @pytest.mark.parametrize("target_obj", [{}, {"a": 1}, {"a": "2"}, {"a": 1, "b": "2"}]) + @pytest.mark.parametrize( + "target_obj", [{}, {"a": 1}, {"a": "2"}, {"a": 1, "b": "2"}] + ) def test_anyOf_additional_properties(self, target_obj): # First sanity check what we're setting up schema_obj = json.loads(self.anyOf_schema) @@ -1447,19 +1533,23 @@ def test_anyOf_additional_properties(self, target_obj): generate_and_check(target_obj, schema_obj) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ({"a": 1.5}, b'{"a":1', b".", {Byte(b","), Byte(b"}"), *INTEGER_FOLLOWING}), - ({"a": True}, b'{"a":', b"t", {Byte(b'"'), *INTEGER_LEADING}), + ({"a": 1.5}, b'{"a":1', b".", {b",", b"}", *INTEGER_FOLLOWING}, True), + ({"a": True}, b'{"a":', b"t", {b'"', *INTEGER_LEADING}, True), ( {"a": 1, "b": False}, b'{"a":1,"b":', b"f", - {Byte(b'"'), *INTEGER_LEADING}, + {b'"', *INTEGER_LEADING}, + True, ), ], ) - def test_anyOf_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_anyOf_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(self.anyOf_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1468,6 +1558,8 @@ def test_anyOf_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( @@ -1488,14 +1580,19 @@ def test_properties_and_additional_properties(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ({}, b"{", b"}", {Byte(b'"')}), - ({"a": 1}, b'{"', b"a", {Byte(b"m")}), - ({"a": 1, "b": 2}, b'{"', b"a", {Byte(b"m")}), + ({}, b"{", b"}", {b'"'}, True), + ({"a": 1}, b'{"', b"a", {b"m"}, False), + ({"a": 1, "b": 2}, b'{"', b"a", {b"m"}, False), ], ) - def test_combined_missing_properties(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_combined_missing_properties( + self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact + ): schema_obj = json.loads(self.combined_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1504,22 +1601,28 @@ def test_combined_missing_properties(self, bad_obj, good_bytes, failure_byte, al failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - ({"mystr": 1}, b'{"mystr":', b"1", {Byte(b'"')}), - ({"mystr": 1, "a": 2}, b'{"mystr":', b"1", {Byte(b'"')}), + ({"mystr": 1}, b'{"mystr":', b"1", {b'"'}, True), + ({"mystr": 1, "a": 2}, b'{"mystr":', b"1", {b'"'}, True), ( {"mystr": "hello", "a": False}, b'{"mystr":"hello","a":', b"f", INTEGER_LEADING, + True, ), ], ) - def test_combined_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): + def test_combined_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact): schema_obj = json.loads(self.combined_schema) bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1528,6 +1631,8 @@ def test_combined_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_byte failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @@ -1622,34 +1727,41 @@ def test_empty_schema(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_string, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ # {} is not carte blanche for malformed JSON - ("{a:1}", b"{", b"a", {Byte(b'"'), Byte(b"}")}), + ("{a:1}", b"{", b"a", {b'"', b"}"}, True), ( "[1,2} ", b"[1,2", b"}", - {Byte(b","), Byte(b"]"), Byte(b"e"), Byte(b"."), *INTEGER_FOLLOWING}, + {b",", b"]", b"e", b"E", b".", *INTEGER_FOLLOWING}, + True, ), - ("123a", b"123", b"a", {Byte(b"e"), Byte(b"."), *INTEGER_FOLLOWING}), + ("123a", b"123", b"a", {b"e", b"E", b".", *INTEGER_FOLLOWING}, True), ( "]", b"", b"]", { - Byte(b"["), - Byte(b"{"), - Byte(b'"'), - Byte(b"t"), - Byte(b"f"), - Byte(b"n"), + b"[", + b"{", + b'"', + b"t", + b"f", + b"n", *INTEGER_LEADING, }, + False, ), ], ) - def test_bad_empty_schema(self, bad_string, good_bytes, failure_byte, allowed_bytes): + def test_bad_empty_schema( + self, bad_string, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact + ): schema_obj = json.loads(self.empty_schema) check_match_failure( bad_string=bad_string, @@ -1657,6 +1769,8 @@ def test_bad_empty_schema(self, bad_string, good_bytes, failure_byte, allowed_by failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( @@ -1700,14 +1814,17 @@ def test_nested_empty_schema(self, schema_obj, target_obj, temperature): ], ) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ # Missing property -- presence of {} deeper in the schema isn't carte blanche - ({"b": 42}, b'{"', b"b", {Byte(b"a")}), + ({"b": 42}, b'{"', b"b", {b"a"}, False), ], ) def test_nested_empty_schema_bad( - self, schema_obj, bad_obj, good_bytes, failure_byte, allowed_bytes + self, schema_obj, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): bad_string = _to_compact_json(bad_obj) check_match_failure( @@ -1716,6 +1833,8 @@ def test_nested_empty_schema_bad( failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( @@ -1742,14 +1861,17 @@ def test_nested_empty_schema_with_props(self, target_obj, temperature): generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ # Missing property -- presence of {} deeper in the schema isn't carte blanche - ({"b": 42}, b'{"', b"b", {Byte(b"a")}), + ({"b": 42}, b'{"', b"b", {b"a"}, False), ], ) def test_nested_empty_schema_with_props_bad( - self, bad_obj, good_bytes, failure_byte, allowed_bytes + self, bad_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): schema_obj = json.loads(self.nested_empty_schema_with_props) @@ -1760,6 +1882,8 @@ def test_nested_empty_schema_with_props_bad( failure_byte=failure_byte, allowed_bytes=allowed_bytes, schema_obj=schema_obj, + maybe_whitespace=maybe_whitespace, + compact=compact, ) @pytest.mark.parametrize( @@ -1771,19 +1895,23 @@ def test_nested_empty_schema_with_props_bad( ], ) def test_items(self, schema_obj): - schema_obj = {"type": "array"} generate_and_check( [1, 0.4, "hello", False, None, {"a": 42}, [1, 2, 3, "four"]], schema_obj ) - def test_no_items(self): + @pytest.mark.parametrize( + "compact", [True, False] + ) + def test_no_items(self, compact): schema_obj = {"type": "array", "items": False} check_match_failure( bad_string="[42]", good_bytes=b"[", failure_byte=b"4", - allowed_bytes={Byte(b"]")}, # array must be empty + allowed_bytes={b"]"}, # array must be empty schema_obj=schema_obj, + maybe_whitespace=True, + compact=compact, ) @pytest.mark.parametrize( @@ -1808,12 +1936,17 @@ def test_additionalProperties(self, schema_obj): schema_obj, ) - def test_no_additionalProperties(self): + @pytest.mark.parametrize( + "compact", [True, False] + ) + def test_no_additionalProperties(self, compact): schema_obj = {"type": "object", "additionalProperties": False} check_match_failure( bad_string='{"a": 42}', good_bytes=b"{", failure_byte=b'"', - allowed_bytes={Byte(b"}")}, # object must be empty + allowed_bytes={b"}"}, # object must be empty schema_obj=schema_obj, + maybe_whitespace=True, + compact=compact, ) diff --git a/tests/unit/library/test_pydantic.py b/tests/unit/library/test_pydantic.py index e15f623d2..0cf662914 100644 --- a/tests/unit/library/test_pydantic.py +++ b/tests/unit/library/test_pydantic.py @@ -1,15 +1,15 @@ import inspect from json import dumps as json_dumps +from functools import partial from typing import Any, Dict, Generic, List, Literal, Tuple, Type, TypeVar, Union, Set import pydantic import pytest from pydantic.json_schema import to_jsonable_python as pydantic_to_jsonable_python -from guidance import json as gen_json -from guidance import models -from guidance._grammar import Byte, ByteRange -from ...utils import check_match_failure as _check_match_failure +from guidance import models, json as gen_json +from guidance.library._json import WHITESPACE +from ...utils import check_match_failure as _check_match_failure, generate_and_check as _generate_and_check def to_compact_json(target: Any) -> str: @@ -66,40 +66,33 @@ def generate_and_check( ): # Sanity check what we're being asked target_obj = validate_obj(target_obj, pydantic_model) + prepared_json = to_compact_json(target_obj) + assert validate_string(prepared_json, pydantic_model) == target_obj - # Define grammar with capture key - CAPTURE_KEY = "my_capture" - grammar = gen_json(name=CAPTURE_KEY, schema=pydantic_model) - - # Test that grammar matches string - json_string = to_compact_json(target_obj) - matches = grammar.match(json_string, raise_exceptions=True) - assert matches.partial == False - - # Run with the mock model - prepared_string = f"{json_string}" - lm = models.Mock(prepared_string.encode(), echo=False) - lm += grammar - - # Make sure the round trip works - round_trip_object = validate_string(lm[CAPTURE_KEY], pydantic_model) - assert round_trip_object == target_obj + # Check that the grammar can produce the literal prepared_json string + grammar_callable = partial(gen_json, schema=pydantic_model) + _generate_and_check(grammar_callable, prepared_json) def check_match_failure( bad_obj: Any, good_bytes: bytes, failure_byte: bytes, - allowed_bytes: Set[Union[Byte, ByteRange]], + allowed_bytes: Set[bytes], pydantic_model: Union[Type[pydantic.BaseModel], pydantic.TypeAdapter], + maybe_whitespace: bool, + compact: bool, ): bad_string = to_compact_json(bad_obj) - grammar = gen_json(schema=pydantic_model) + grammar = gen_json(schema=pydantic_model, compact=compact) _check_match_failure( bad_string=bad_string, good_bytes=good_bytes, failure_byte=failure_byte, - allowed_bytes=allowed_bytes, + allowed_bytes=( + allowed_bytes.union(WHITESPACE) if (maybe_whitespace and not compact) + else allowed_bytes + ), grammar=grammar, ) @@ -174,14 +167,19 @@ def test_heterogeneous(self): model = pydantic.TypeAdapter(Tuple[int, bool]) generate_and_check((1, True), model) - def test_maxitems(self): + @pytest.mark.parametrize( + "compact", [True, False] + ) + def test_maxitems(self, compact): model = pydantic.TypeAdapter(Tuple[int,]) check_match_failure( bad_obj=(1, 2), good_bytes=b"[1", failure_byte=b",", - allowed_bytes={ByteRange(b"09"), Byte(b"]")}, + allowed_bytes={b"]", *{bytes([i]) for i in range(ord("0"), ord("9") + 1)}}, pydantic_model=model, + maybe_whitespace=True, + compact=compact, ) @@ -252,15 +250,18 @@ def test_generic(self, my_type, my_obj): generate_and_check(obj, model) @pytest.mark.parametrize( - "my_type, my_obj, good_bytes, failure_byte, allowed_bytes", + "compact", [True, False] + ) + @pytest.mark.parametrize( + "my_type, my_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace", [ - (bool, "True", b"", b'"', {Byte(b"t"), Byte(b"f")}), - (str, 42, b"", b"4", {Byte(b'"')}), - (int, False, b"", b"f", {Byte(b"0"), ByteRange(b"19"), Byte(b"-")}), + (bool, "True", b"", b'"', {b"t", b"f"}, True), + (str, 42, b"", b"4", {b'"'}, True), + (int, False, b"", b"f", {b"-", *{bytes([i]) for i in range(ord("0"), ord("9") + 1)}}, True), ], ) def test_bad_generic( - self, my_type, my_obj, good_bytes, failure_byte, allowed_bytes + self, my_type, my_obj, good_bytes, failure_byte, allowed_bytes, maybe_whitespace, compact ): model = self.SimpleGeneric[my_type] obj = {"my_obj": my_obj} @@ -270,4 +271,6 @@ def test_bad_generic( failure_byte=failure_byte, allowed_bytes=allowed_bytes, pydantic_model=model, + maybe_whitespace=maybe_whitespace, + compact=compact, ) diff --git a/tests/unit/library/test_regex.py b/tests/unit/library/test_regex.py index 2599ce211..ef079a0a6 100644 --- a/tests/unit/library/test_regex.py +++ b/tests/unit/library/test_regex.py @@ -1,33 +1,89 @@ import pytest from functools import partial -from guidance import regex -from guidance._grammar import Byte, ByteRange +from guidance.library._gen import regex from ...utils import check_match_failure, generate_and_check +def byte_range(byterange: bytes): + start, end = byterange + return {bytes([i]) for i in range(start, end + 1)} + + +ASCII_START_BYTES = byte_range(b"\x00\x7f") +UNICODE_SPECIAL_START_BYTES = byte_range(b"\xc2\xf4") +UNICODE_START_BYTES = ASCII_START_BYTES | UNICODE_SPECIAL_START_BYTES + +# Equivalent to the following (in python 3.12) +# { +# char.encode()[:1] +# for codepoint in range(0x110000) +# if unicodedata.category(char := chr(codepoint)) == "Nd" +# } +UNICODE_DIGIT_START_BYTES = byte_range(b"09") | { + b"\xd9", + b"\xdb", + b"\xdf", + b"\xe0", + b"\xe1", + b"\xea", + b"\xef", + b"\xf0", +} + +# Equivalent to the following (in python 3.12) +# { +# char.encode()[:1] +# for codepoint in range(0x110000) +# if unicodedata.category(char := chr(codepoint))[:1] in {"L", "M", "N"} +# } | {b"_"} +UNICODE_WORD_START_BYTES = ( + byte_range(b"09") + | byte_range(b"az") + | byte_range(b"AZ") + | {b"_"} + | (byte_range(b"\xc2\xdf") | byte_range(b"\xe0\xed") | byte_range(b"\xef\xf0") | {b"\xf3"}) +) + +# Equivalent to the following (in python 3.12) +# { +# char.encode("utf-8", "surrogatepass")[:1] +# for codepoint in range(0x110000) +# if unicodedata.category(char := chr(codepoint))[:1] not in {"L", "M", "N"} +# } - {b"_"} +UNICODE_NON_WORD_START_BYTES = ( + ASCII_START_BYTES - (byte_range(b"09") | byte_range(b"az") | byte_range(b"AZ") | {b"_"}) +) | ( + byte_range(b"\xcd\xcf") + | byte_range(b"\xd4\xd9") + | byte_range(b"\xdb\xe4") + | byte_range(b"\xed\xf4") + | {b"\xc2", b"\xc3", b"\xd2", b"\xcb", b"\xea"} +) + + class TestCharacterClasses: @pytest.mark.parametrize( - "pattern, string, stop_char", + "pattern, string", [ - (r"[abc]+", "cbabbaccabc", chr(7)), - (r"[a-z]+", "thequickbrownfoxjumpsoverthelazydog", chr(7)), - (r"[0-9]+", "9876543210", chr(7)), - (r"[b-y]+", "by", chr(7)), # range is left and right inclusive - (r"[a-f0-9]+", "abcdef0123456789", chr(7)), - (r"[abcA-Z]+", "abcABCXYZ", chr(7)), - (r"[a-z\d]+", "abc123", chr(7)), - (r"[^abc]+", "ABCxyz8743-!@#$%^&*()_+", "a"), - (r"[^\d]+", "abcXYZ-!@#$%^&*()_+", "8"), - (r"[^B-Z]+", "qwertyA", "B"), - (r"[^a-z\d]+", "ABCDEF-!@#$%^&*()_+", "a"), - (r"[^\n]+", "ABCxyz8743-!@#$%^&*()_+", "\n"), + (r"[abc]+", "cbabbaccabc"), + (r"[a-z]+", "thequickbrownfoxjumpsoverthelazydog"), + (r"[0-9]+", "9876543210"), + (r"[b-y]+", "by"), # range is left and right inclusive + (r"[a-f0-9]+", "abcdef0123456789"), + (r"[abcA-Z]+", "abcABCXYZ"), + (r"[a-z\d]+", "abc123"), + (r"[^abc]+", "ABCxyz8743-!@#$%^&*()_+"), + (r"[^\d]+", "abcXYZ-!@#$%^&*()_+"), + (r"[^B-Z]+", "qwertyA"), + (r"[^a-z\d]+", "ABCDEF-!@#$%^&*()_+"), + (r"[^\n]+", "ABCxyz8743-!@#$%^&*()_+"), ], ) - def test_good(self, pattern, string, stop_char): + def test_good(self, pattern, string): grammar_callable = partial(regex, pattern=pattern) - generate_and_check(grammar_callable, string, stop_char=stop_char) + generate_and_check(grammar_callable, string) @pytest.mark.parametrize( "pattern, string, good_bytes, failure_byte, allowed_bytes", @@ -37,91 +93,91 @@ def test_good(self, pattern, string, stop_char): "cbabbaccabcx", b"cbabbaccabc", b"x", - {Byte(b"a"), Byte(b"b"), Byte(b"c")}, + {b"a", b"b", b"c"}, ), ( r"[a-z]+", "thequickbrownfoxjumpsoverthelazydogX", b"thequickbrownfoxjumpsoverthelazydog", b"X", - {ByteRange((b"az"))}, + byte_range(b"az"), ), ( r"[0-9]+", "9876543210x", b"9876543210", b"x", - {ByteRange((b"09"))}, + byte_range(b"09"), ), ( r"[b-y]+", "bya", b"by", b"a", - {ByteRange(b"by")}, + byte_range(b"by"), ), # range doesn't overflow left ( r"[b-y]+", "byz", b"by", b"z", - {ByteRange(b"by")}, + byte_range(b"by"), ), # range doesn't overflow right ( r"[a-f0-9]+", "abcdef0123456789x", b"abcdef0123456789", b"x", - {ByteRange(b"af"), ByteRange(b"09")}, + {*byte_range(b"af"), *byte_range(b"09")}, ), ( r"[abcA-Z]+", "abcABCXYZx", b"abcABCXYZ", b"x", - {Byte(b"a"), Byte(b"b"), Byte(b"c"), ByteRange(b"AZ")}, + {b"a", b"b", b"c", *byte_range(b"AZ")}, ), ( r"[a-z\d]+", "abc123@", b"abc123", b"@", - {ByteRange(b"az"), ByteRange(b"09")}, + byte_range(b"az") | UNICODE_DIGIT_START_BYTES, ), ( r"[^abc]+", "ABCxyz8743-!@#$%^&*()_+a", b"ABCxyz8743-!@#$%^&*()_+", b"a", - {ByteRange(b"\x00`"), ByteRange(b"d\x7f")}, + UNICODE_START_BYTES - {b"a", b"b", b"c"}, ), ( r"[^\d]+", "abcXYZ-!@#$%^&*()_+6", b"abcXYZ-!@#$%^&*()_+", b"6", - {ByteRange(b"\x00/"), ByteRange(b":\x7f")}, + UNICODE_START_BYTES - byte_range(b"09"), ), ( r"[^B-Z]+", "qwertyAB", b"qwertyA", b"B", - {ByteRange(b"\x00A"), ByteRange(b"[\x7f")}, + UNICODE_START_BYTES - byte_range(b"BZ"), ), ( r"[^a-z\d]+", "ABCDEF-!@#$%^&*()_+x", b"ABCDEF-!@#$%^&*()_+", b"x", - {ByteRange(b"\x00/"), ByteRange(b":`"), ByteRange(b"{\x7f")}, + UNICODE_START_BYTES - (byte_range(b"az") | byte_range(b"09")), ), ( r"[^\n]+", "ABCxyz8743-!@#$%^&*()_+\n", b"ABCxyz8743-!@#$%^&*()_+", b"\n", - {ByteRange(b"\x00\t"), ByteRange(b"\x0b\x7f")}, + UNICODE_START_BYTES - {b"\n"}, ), ], ) @@ -185,42 +241,42 @@ def test_nested_quantifiers(self, pattern, string): "axb", b"a", b"x", - {Byte(b"a"), Byte(b"b")}, + {b"a", b"b"}, ), # 'x' disrupts the match ( r"a+b", "b", b"", b"b", - {Byte(b"a")}, + {b"a"}, ), # 'a+' requires at least one 'a' before 'b' ( r"a?b", "x", b"", b"x", - {Byte(b"a"), Byte(b"b")}, + {b"a", b"b"}, ), # 'a?' requires zero or one 'a' before 'b' ( r"a?b", "axb", b"a", b"x", - {Byte(b"b")}, + {b"b"}, ), # 'x' disrupts the match ( r"a?b", "aab", b"a", b"a", - {Byte(b"b")}, + {b"b"}, ), # Second 'a' is too many ( r"(xyz)?abc", "xyabc", b"xy", b"a", - {Byte(b"z")}, + {b"z"}, ), # Expected 'z' ( r"(xyz)?abc", @@ -248,20 +304,18 @@ def test_nested_quantifiers(self, pattern, string): "aab", b"aa", b"b", - {Byte(b"a")}, + {b"a"}, ), # Less than the minimum 'a's before 'b' ( r"a{3,5}b", "aaaaaab", b"aaaaa", b"a", - {Byte(b"b")}, + {b"b"}, ), # More than the maximum 'a's before 'b' ], ) - def test_quantifiers_failure( - self, pattern, string, good_bytes, failure_byte, allowed_bytes - ): + def test_quantifiers_failure(self, pattern, string, good_bytes, failure_byte, allowed_bytes): check_match_failure( bad_string=string, good_bytes=good_bytes, @@ -332,55 +386,53 @@ def test_alternations_with_quantifiers(self, pattern, string): "c", b"", b"c", - {Byte(b"a"), Byte(b"b")}, + {b"a", b"b"}, ), # Neither 'a' nor 'b' ( r"apple|orange", "banana", b"", b"b", - {Byte(b"a"), Byte(b"o")}, + {b"a", b"o"}, ), # Neither 'apple' nor 'orange' ( r"100|200", "300", b"", b"3", - {Byte(b"1"), Byte(b"2")}, + {b"1", b"2"}, ), # Neither '100' nor '200' ( r"(a|b)c|d", "ae", b"a", b"e", - {Byte(b"c"), Byte(b"c")}, + {b"c", b"c"}, ), # Neither 'ac' nor 'bc' nor 'd' ( r"(a|b)+", "abbaabbabc", b"abbaabbab", b"c", - {Byte(b"a"), Byte(b"b")}, + {b"a", b"b"}, ), # 'c' does not match pattern '(a|b)+' ( r"cat|dog", "car", b"ca", b"r", - {Byte(b"t")}, + {b"t"}, ), # 't' should be forced ( r"(dog|cat)s?", "cars", b"ca", b"r", - {Byte(b"t")}, + {b"t"}, ), # 't' should be forced ], ) - def test_alternations_failures( - self, pattern, string, good_bytes, failure_byte, allowed_bytes - ): + def test_alternations_failures(self, pattern, string, good_bytes, failure_byte, allowed_bytes): check_match_failure( bad_string=string, good_bytes=good_bytes, @@ -399,7 +451,7 @@ class TestDot: ) def test_dot(self, pattern, string): grammar_callable = partial(regex, pattern=pattern) - generate_and_check(grammar_callable, string, stop_char="\n") + generate_and_check(grammar_callable, string) @pytest.mark.parametrize( "pattern, string, good_bytes, failure_byte, allowed_bytes", @@ -409,13 +461,11 @@ def test_dot(self, pattern, string): "ABCxyz8743-!@#$%^&*()_+\n", b"ABCxyz8743-!@#$%^&*()_+", b"\n", - {ByteRange(b"\x00\t"), ByteRange(b"\x0b\x7f")}, + UNICODE_START_BYTES - {b"\n"}, ), ], ) - def test_dot_failures( - self, pattern, string, good_bytes, failure_byte, allowed_bytes - ): + def test_dot_failures(self, pattern, string, good_bytes, failure_byte, allowed_bytes): check_match_failure( bad_string=string, good_bytes=good_bytes, @@ -427,25 +477,25 @@ def test_dot_failures( class TestSpecialCharacters: @pytest.mark.parametrize( - "pattern, string, stop_char", + "pattern, string", [ - (r"\d+", "1234567890", chr(7)), - (r"[^\D]+", "1234567890", chr(7)), - (r"\D+", "ABCxyz-!@#$%^&*()_+", "9"), - (r"[^\d]+", "ABCxyz-!@#$%^&*()_+", "9"), - (r"\w+", "abcABC123_", chr(7)), - (r"[^\W]+", "abcABC123_", chr(7)), - (r"\W+", " -!@#$%^&*()+", "9"), - (r"[^\w]+", "-!@#$%^&*()+", "9"), - (r"\s+", " \t\n\r\f\v", chr(7)), - (r"[^\S]+", " \t\n\r\f\v", chr(7)), - (r"\S+", "ABCxyz8743-!@#$%^&*()_+", " "), - (r"[^\s]+", "ABCxyz8743-!@#$%^&*()_+", " "), + (r"\d+", "1234567890"), + (r"[^\D]+", "1234567890"), + (r"\D+", "ABCxyz-!@#$%^&*()_+"), + (r"[^\d]+", "ABCxyz-!@#$%^&*()_+"), + (r"\w+", "abcABC123_"), + (r"[^\W]+", "abcABC123_"), + (r"\W+", " -!@#$%^&*()+"), + (r"[^\w]+", "-!@#$%^&*()+"), + (r"\s+", " \t\n\r\f\v"), + (r"[^\S]+", " \t\n\r\f\v"), + (r"\S+", "ABCxyz8743-!@#$%^&*()_+"), + (r"[^\s]+", "ABCxyz8743-!@#$%^&*()_+"), ], ) - def test_good(self, pattern, string, stop_char): + def test_good(self, pattern, string): grammar_callable = partial(regex, pattern=pattern) - generate_and_check(grammar_callable, string, stop_char=stop_char) + generate_and_check(grammar_callable, string) @pytest.mark.parametrize( "pattern, string, good_bytes, failure_byte, allowed_bytes", @@ -455,55 +505,39 @@ def test_good(self, pattern, string, stop_char): "0123456789x", b"0123456789", b"x", - {ByteRange(b"09")}, + UNICODE_DIGIT_START_BYTES, ), ( r"\D+", "ABCxyz-!@#$%^&*()_+1", b"ABCxyz-!@#$%^&*()_+", b"1", - {ByteRange(b"\x00/"), ByteRange(b":\x7f")}, - ), - ( - r"\w+", - "abcABC123_@", - b"abcABC123_", - b"@", - {ByteRange(b"az"), ByteRange(b"AZ"), ByteRange(b"09"), Byte(b"_")}, - ), - ( - r"\W+", - " -!@#$%^&*()+a", - b" -!@#$%^&*()+", - b"a", - { - ByteRange(b"\x00/"), - ByteRange(b":@"), - ByteRange(b"[^"), - Byte(b"`"), - ByteRange(b"{\x7f"), - }, + UNICODE_START_BYTES - byte_range(b"09"), ), + (r"\w+", "abcABC123_@", b"abcABC123_", b"@", UNICODE_WORD_START_BYTES), + (r"\W+", " -!@#$%^&*()+a", b" -!@#$%^&*()+", b"a", UNICODE_NON_WORD_START_BYTES), ( r"\s+", " \t\n\r\f\v8", b" \t\n\r\f\v", b"8", - { - Byte(b" "), - Byte(b"\t"), - Byte(b"\n"), - Byte(b"\r"), - Byte(b"\f"), - Byte(b"\v"), - }, + {b" ", b"\t", b"\n", b"\r", b"\f", b"\v"} + | {b"\xc2", b"\xe1", b"\xe2", b"\xe3"}, # include unicode whitespace starts ), ( r"\S+", "abcABC123_ ", b"abcABC123_", b" ", - {ByteRange(b"\x00\x08"), ByteRange(b"\x0e\x1f"), ByteRange(b"!\x7f")}, + UNICODE_START_BYTES + - { + b" ", + b"\t", + b"\n", + b"\r", + b"\f", + b"\v", + }, ), ], ) diff --git a/tests/unit/library/test_sequences.py b/tests/unit/library/test_sequences.py index e98d3d8fc..316a5ccdf 100644 --- a/tests/unit/library/test_sequences.py +++ b/tests/unit/library/test_sequences.py @@ -21,9 +21,9 @@ def test_with_select(self, test_string): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("bbb", b"bbb", b"B", set([Byte(b"b")])), - ("bbbbb", b"bbbb", b"b", set([Byte(b"B")])), - ("aaaa", b"", b"a", set([Byte(b"b")])), + ("bbb", b"bbb", b"B", {b"b"}), + ("bbbbb", b"bbbb", b"b", {b"B"}), + ("aaaa", b"", b"a", {b"b"}), ], ) def test_bad_repeats(self, bad_string: str, good_bytes, failure_byte, allowed_bytes): @@ -59,8 +59,8 @@ def test_with_select(self, test_string): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("bbbbb", b"bbbb", b"b", set([Byte(b"B")])), - ("aaaa", b"", b"a", set([Byte(b"b"), Byte(b"B")])), + ("bbbbb", b"bbbb", b"b", {b"B"}), + ("aaaa", b"", b"a", {b"b", b"B"}), ], ) def test_bad_repeats(self, bad_string: str, good_bytes, failure_byte, allowed_bytes): @@ -105,8 +105,8 @@ def test_unconstrained(self, test_string: str): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("ba", b"b", b"a", set([Byte(b"b"), Byte(b"B")])), - ("a", b"", b"a", set([Byte(b"b"), Byte(b"B")])), + ("ba", b"b", b"a", {b"b", b"B"}), + ("a", b"", b"a", {b"b", b"B"}), ], ) def test_bad_repeats_unconstrained( @@ -136,10 +136,10 @@ def test_min_length_zero(self, test_string): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("bbb", b"bbb", b"B", set([Byte(b"b")])), - ("aaaa", b"", b"a", set([Byte(b"b")])), - ("bbbba", b"bbbb", b"a", set([Byte(b"b"), Byte(b"B")])), - ("bbbbbba", b"bbbbbb", b"a", set([Byte(b"b"), Byte(b"B")])), + ("bbb", b"bbb", b"B", {b"b"}), + ("aaaa", b"", b"a", {b"b"}), + ("bbbba", b"bbbb", b"a", {b"b", b"B"}), + ("bbbbbba", b"bbbbbb", b"a", {b"b", b"B"}), ], ) def test_bad_repeats_min_length( @@ -170,8 +170,8 @@ def test_max_length_zero(self): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("bbb", b"bb", b"b", set([Byte(b"B")])), - ("aa", b"", b"a", set([Byte(b"b"), Byte(b"B")])), + ("bbb", b"bb", b"b", {b"B"}), + ("aa", b"", b"a", {b"b", b"B"}), ], ) def test_bad_repeats_max_length( @@ -235,11 +235,11 @@ def test_min_max_length_equal(self, n_repeats): @pytest.mark.parametrize( ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("", b"", b"B", set([Byte(b"b")])), - ("bbb", b"bb", b"b", set([Byte(b"B")])), - ("aa", b"", b"a", set([Byte(b"b")])), - ("ba", b"b", b"a", set([Byte(b"b"), Byte(b"B")])), - ("bba", b"bb", b"a", set([Byte(b"B")])), + ("", b"", b"B", {b"b"}), + ("bbb", b"bb", b"b", {b"B"}), + ("aa", b"", b"a", {b"b"}), + ("ba", b"b", b"a", {b"b", b"B"}), + ("bba", b"bb", b"a", {b"B"}), ], ) def test_bad_repeats_min_max_length( diff --git a/tests/unit/library/test_subgrammar.py b/tests/unit/library/test_subgrammar.py new file mode 100644 index 000000000..7e81545b5 --- /dev/null +++ b/tests/unit/library/test_subgrammar.py @@ -0,0 +1,27 @@ +import pytest +from guidance.library._subgrammar import subgrammar, lexeme + +class TestEndingLexemeAmbiguous: + @pytest.mark.parametrize( + "skip_rx", + [None, r"\s", r"\s+", r"\s*"] + ) + @pytest.mark.parametrize( + "string", + ["123"] + ) + def test_lexeme_can_be_done_even_if_could_match_more(self, string, skip_rx): + g1 = subgrammar(body=lexeme(r"\d+"), skip_regex=skip_rx, name="mycap") + assert (m := g1.match(string)) is not None and m.captures["mycap"] == string + g2 = g1 + "x" + assert (m := g2.match(f"{string}x")) is not None and m.captures["mycap"] == string + + @pytest.mark.parametrize( + "string", + ["1", "123", "1x", "123x"] + ) + def test_nullable_final_lexeme(self, string): + g = subgrammar(body=lexeme(r"\d+")+lexeme(r"x?"), name="mycap") + match = g.match(string) + assert match is not None + assert match.captures["mycap"] == string diff --git a/tests/unit/library/test_substring.py b/tests/unit/library/test_substring.py index 7c2db7d56..51277d02e 100644 --- a/tests/unit/library/test_substring.py +++ b/tests/unit/library/test_substring.py @@ -21,7 +21,7 @@ ], ) def test_mocked_substring(mock_string, target_string, expected_string): - m = models.Mock(f"{mock_string}") + m = models.Mock(f"{mock_string}") lm = m + substring(target_string, name="result") assert lm["result"] == expected_string diff --git a/tests/unit/test_grammar.py b/tests/unit/test_grammar.py index 3b2774ffc..81f633dbe 100644 --- a/tests/unit/test_grammar.py +++ b/tests/unit/test_grammar.py @@ -1,5 +1,6 @@ +import pytest import guidance -from guidance import gen, models, optional, select +from guidance import gen, models, optional, select, string def test_select_reset_pos(): @@ -15,6 +16,22 @@ def test_select_longer(): assert lm["text"] == "nice man." +@pytest.mark.xfail( + reason="Lexer sees 'a' then 'b' and here decides to continue matching abq)" +) +def test_select_ambiguous_lexeme_boundary(): + lm = models.Mock(b"abQ") + lm += select(options=["a", "abq", "c"], name="prefix") + optional("bQ") + assert lm["prefix"] == "a" + + +def test_select_ambiguous_lexeme_boundary_manual_fix(): + # Manual fix to the issue in test_select_ambiguous_lexeme_boundary by splitting the "abq" lexeme into two lexemes + lm = models.Mock(b"abQ") + lm += select(options=["a", string("a")+string("bq"), "c"], name="prefix") + optional("bQ") + assert lm["prefix"] == "a" + + def test_select_empty(): """This tests to ensure that we save empty capture groups.""" lm = models.Mock(b"This is a test") @@ -90,3 +107,29 @@ def grammar3(lm): grammar1() grammar2() grammar3() + +class TestMatch: + @pytest.mark.parametrize( + "string", + ["456", "456x"] + ) + def test_full_match(self, string): + g = "123" + gen(regex=r"\d+x?", name="mycap") + match = g.match(f"123{string}") + assert match is not None + assert not match.partial + assert match.captures["mycap"] == string + + @pytest.mark.parametrize( + "string", + # "456" fails -- think about supporting? + # (reasonable to expect either behavior) + ["456x"] + ) + def test_partial_match(self, string): + g = "123" + gen(regex=r"\d+x?", name="mycap") + "789" + assert g.match(f"123{string}") is None + match = g.match(f"123{string}", allow_partial=True) + assert match is not None + assert match.partial + assert match.captures["mycap"] == string \ No newline at end of file diff --git a/tests/unit/test_ll.py b/tests/unit/test_ll.py new file mode 100644 index 000000000..56ca1e533 --- /dev/null +++ b/tests/unit/test_ll.py @@ -0,0 +1,380 @@ +from typing import Any, List +import tokenizers +import llguidance +import json +import textwrap +import guidance +import pytest +from guidance import ( + gen, + select, + optional, + byte_range, + one_or_more, + GrammarFunction, + string, +) +from guidance._grammar import as_regular_grammar +from guidance.library._subgrammar import subgrammar, lexeme + +log_level = 10 + + +class PhiTokenizer: + _ll_tokenizer = None + _instance = None + + @staticmethod + def instance(): + if PhiTokenizer._instance is None: + PhiTokenizer._instance = PhiTokenizer() + return PhiTokenizer._instance + + @staticmethod + def ll_tokenizer(): + if PhiTokenizer._ll_tokenizer is None: + PhiTokenizer._ll_tokenizer = llguidance.LLTokenizer( + llguidance.TokenizerWrapper(PhiTokenizer()) + ) + return PhiTokenizer._ll_tokenizer + + def tokenize_str(self, s: str) -> List[int]: + return self.hf_tokenizer.encode(s).ids + + def __init__(self) -> None: + self.hf_tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_pretrained( + "microsoft/Phi-3-mini-128k-instruct" + ) + empty = self.tokenize_str("") + if empty: + self.bos_token_id = empty[0] + else: + self.bos_token_id = None + eos = self.tokenize_str("") + assert len(eos) == 1 + self.eos_token_id = eos[0] + self.tokens = [] + for i in range(self.hf_tokenizer.get_vocab_size()): + t: str = self.hf_tokenizer.id_to_token(i) + if t.startswith("<0x"): + self.tokens.append(bytes([int(t[3:5], 16)])) + else: + t = t.replace("▁", " ") + self.tokens.append(t.encode("utf-8")) + assert len(self.tokens) == self.hf_tokenizer.get_vocab_size() + + def __call__(self, s): + return self.tokenize_str(s) + + +def check_eq(label: str, tokens: List[int], expected_tokens: str): + if log_level > 0: + print(f"Checking {label}: {repr(expected_tokens)}") + t = PhiTokenizer.ll_tokenizer() + actual_tokens = t.test_trace_tokens(tokens) + assert ( + actual_tokens == expected_tokens + ), f"Tokens mismatch in {label}\n {repr(actual_tokens)}\n {repr(expected_tokens)}" + + +def tokenize_trace(s: str): + if log_level > 0: + print("Tokenizing", repr(s)) + r: List[int] = [] + for word in s.split("‧"): + if word == "≺EOS≻": + r.append(PhiTokenizer.instance().eos_token_id) + continue + tt = PhiTokenizer.ll_tokenizer().tokenize_str(word) + assert len(tt) == 1, f"Expected single token for {repr(word)} got {tt}" + r.append(tt[0]) + return r + + +def check_grammar(grm: GrammarFunction, output: List[str]): + """ + Check that the grammar generates the expected output. + + Output is a list of strings, each of which is a sequence of tokens. + Tokens in the string are separated with "‧". + Strings at even positions are "forced tokens", and strings at odd positions + are "generated tokens". + We check that the grammars forces the forced tokens (first of which is the + prompt), and that it allows in the mask the generated tokens. + + These tests are "recorded" by passing "test_trace": true in the llguidance + request and post-processing. + """ + print("\nChecking grammar") + interp = llguidance.LLInterpreter( + PhiTokenizer.ll_tokenizer(), json.dumps(grm.ll_serialize()), log_level=log_level + ) + prompt = interp.process_prompt(PhiTokenizer.instance().tokenize_str("")) + check_eq("prompt", prompt, output[0]) + idx = 1 + gen_tokens = tokenize_trace(output[idx]) + for _ in range(200): + mask, cmd = interp.mid_process() + cmd = json.loads(cmd) + if log_level >= 1: + print(mask is not None, cmd) + if cmd["stop"]: + assert idx >= len(output) - 1, f"Expected more output at {idx}" + assert not gen_tokens, "Expected more tokens to generate" + break + if mask: + if not gen_tokens: + raise ValueError("No more tokens to generate") + tok = gen_tokens[0] + del gen_tokens[0:1] + assert mask[tok] > 0, f"Token {tok} not allowed" + bt, toks = interp.post_process(tok) + assert bt == 0 + assert toks == [tok] + else: + bt, toks = interp.post_process(None) + assert not gen_tokens, "Expected more tokens to generate" + idx += 1 + expected = output[idx] + if "↶" in expected: + r = expected.split("↶") + assert len(r) == 2 + expected = r[1] + assert bt == int(r[0]), f"Expected backtrack {r[0]} got {bt}" + check_eq(f"step {idx}", toks, expected) + idx += 1 + if idx < len(output): + gen_tokens = tokenize_trace(output[idx]) + + +def test_llparser(): + grm = ( + "Q: Are dolphins fish?\nA: " + + gen("dolphins", regex="Yes|No", max_tokens=10) + + "\nQ: Are sharks fish?\nA: " + + gen("sharks", regex="Yes|No", max_tokens=10) + ) + check_grammar( + grm, + [ + "Q‧:‧ Are‧ dol‧ph‧ins‧ fish‧?‧\n‧A‧:", + " No", # note the prefix space - moved by token healing + "\n‧Q‧:‧ Are‧ sh‧arks‧ fish‧?‧\n‧A‧:", + " Yes", + ], + ) + + grm = ( + "Power frequency is " + + gen("number", regex="[0-9]+", max_tokens=5) + + "Hz; voltage is " + + gen("number", regex="[0-9]+", max_tokens=5) + + "V" + ) + check_grammar( + grm, + [ + "Power‧ frequency‧ is‧ ", + "5‧0‧Hz", # no EoS needed on 50Hz + ";‧ voltage‧ is‧ ", + "2‧2‧0‧V", + ], + ) + + grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5) + # EoS finishes generation + check_grammar(grm, ["Q‧:‧ ‧7‧ *‧ ‧8‧\n‧A‧:‧ ", "5‧6‧≺EOS≻"]) + + +@pytest.mark.parametrize( + "grm", + [ + # grammar turned into regex: + "Dolphin name: " + + as_regular_grammar('"' + byte_range(b"A", b"Z") + one_or_more(byte_range(b"a", b"z")) + '"') + + ",", + # regular gen() + "Dolphin name: " + gen(regex=r'"[A-Z][a-z]+"') + ",", + # regular gen(), comma in regex + "Dolphin name: " + gen(regex=r'"[A-Z][a-z]+",'), + # regular gen(), quotes outside + 'Dolphin name: "' + gen(regex=r"[A-Z][a-z]+") + '",', + ], +) +@pytest.mark.parametrize( + "output", + [ + ['D‧olph‧in‧ name‧:‧ "', 'F‧li‧pper‧"', ","], # separate comma + ['D‧olph‧in‧ name‧:‧ "', 'F‧li‧pper‧",'], # check that we allow `",` as a single token: + ], +) +def test_ll_dolphin(grm: GrammarFunction, output: List[str]): + check_grammar(grm, output) + + +def test_ll_backtrack_stop(): + grm = "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen("text", stop=",") + "\nNot quite." + check_grammar( + grm, + [ + "Count‧ to‧ ‧1‧0‧:‧ ‧1‧,‧ ‧2‧,‧ ‧3‧,‧ ‧4‧,‧ ‧5‧,‧ ‧6‧,‧ ‧7‧,", + " ‧8‧,", + "1↶\n‧Not‧ quite‧.", + ], + ) + + grm = ( + "Name: " + + gen(regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"]) + + "\nName: " + + gen(regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"]) + ) + check_grammar(grm, ["Name‧:", " Em‧ily", "1↶il‧\n‧Name‧:", " Emil‧ie‧a", "1↶"]) + + +def test_ll_pop_tokens(): + grm = "6 * 7 = " + subgrammar(body=lexeme("[0-9]{1,3}")) + "\n" + check_grammar(grm, ["6‧ *‧ ‧7‧ =‧ ", "4‧2‧\n"]) + + +def test_ll_nullable_lexeme(): + # make sure 'a' is not forced + check_grammar(gen(regex="a*"), ["", "a‧≺EOS≻"]) + # this one doesn't work - no lexeme was scanned by EOS, so we allow more lexemes... + check_grammar(gen(regex="a*"), ["", "≺EOS≻"]) + + # see that we can skip 5* + check_grammar( + "6 * 7 = " + gen(regex="5*") + gen(regex="[1-4][0-9]") + "\n", + ["6‧ *‧ ‧7‧ =‧ ", "4‧2", "\n"], + ) + + check_grammar( + "Here: 2 + 2 = " + subgrammar(name="num", body=lexeme("[0-9]+")), + ["Here‧:‧ ‧2‧ +‧ ‧2‧ =‧ ", "4‧≺EOS≻"], + ) + + # make sure it stops at EOS + check_grammar( + "Here: 2 + 2 = " + subgrammar(name="num", body=lexeme("[0-9]+") + lexeme(r"Q?")), + ["Here‧:‧ ‧2‧ +‧ ‧2‧ =‧ ", "4‧≺EOS≻"], + ) + + num = subgrammar( + body=select( + [ + lexeme(r"-?(?:0|[1-9][0-9]*)", contextual=True), + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)", contextual=True), + ] + ) + ) + + # avoid early stop + check_grammar(num, ["", "1‧≺EOS≻"]) + check_grammar(num, ["", "0‧≺EOS≻"]) + check_grammar(num, ["", "1‧.‧1‧≺EOS≻"]) + check_grammar(num, ["", "0‧.‧1‧≺EOS≻"]) + + +def test_ll_nice_man(): + g = select(["a", "ab", "c"]) + check_grammar(g, ["", "a‧b"]) + check_grammar(g, ["", "a‧≺EOS≻"]) + check_grammar(g + "d", ["", "a‧d"]) + check_grammar(g + "d", ["", "a‧b", "d"]) + check_grammar(g + optional("d"), ["", "a‧b‧d"]) + check_grammar(g + optional("d"), ["", "a‧b‧≺EOS≻"]) + check_grammar(g + optional("d"), ["", "a‧≺EOS≻"]) + + # the example below should work, but only does when string() is used to + # break "abq" into two lexemes + # g = select(["a", "abq", "c"]) + optional("bQ") + g = select(["a", string("a") + string("bq"), "c"]) + optional("bQ") + check_grammar(g, ["", "a‧b‧q‧≺EOS≻"]) + check_grammar(g, ["", "a‧b‧Q"]) + + +def test_ll_stop_quote_comma(): + grm = ( + '{ "items": ["' + + gen("i1", regex=r"a+", stop='"') + + '",\n "' + + gen("i2", regex=r"b+", stop='"') + + '"] }' + ) + # make sure we allow ", as a single token; also "] + check_grammar(grm, ['{‧ "‧items‧":‧ ["', 'a‧",', '\n‧ ‧ "', 'b‧"]', " }"]) + # and as seprate tokens + check_grammar(grm, ['{‧ "‧items‧":‧ ["', 'a‧"', ',‧\n‧ ‧ "', 'b‧"', "]‧ }"]) + + +def test_ll_max_tokens(): + check_grammar( + "Name: " + gen("name", max_tokens=3) + " Height: " + gen("height", max_tokens=3), + ["Name‧:", " Em‧ily‧ Carter", " Height‧:", " ‧5‧'‧6"], + ) + # here we have two gen() with the same regex (so they are the same lexeme) + # but different max_tokens limits + check_grammar( + "Name: " + gen("name", max_tokens=2) + " Height: " + gen("height", max_tokens=3), + ["Name‧:", " Em‧ily", " Height‧:", " ‧5‧'‧6"], + ) + # now this is a strange case, where gen() is allowed together with the following + # string, and gen() runs out of tokens, so the fixed string takes over + # note how Emily is not repeated + check_grammar( + "Name: " + + gen("name", max_tokens=2) + + "Emily Carter is great; Height: " + + gen("height", max_tokens=3), + ["Name‧:", " Em‧ily", " Carter‧ is‧ great‧;‧ Height‧:", " ‧5‧'‧6"], + ) + + +def test_ll_fighter(): + @guidance(stateless=True) + def character_maker2(lm, id, description, valid_weapons): + lm += textwrap.dedent(f"""\ + {{ + "name": "{gen('name', stop='"')}", + "age": {gen('age', regex='[0-9]+', stop=',')}, + "armor": "{select(options=['leather', 'chainmail', 'plate'], name='armor')}", + "weapon": "{select(options=valid_weapons, name='weapon')}", + "class": "{gen('class', stop='"')}", + "mantra": "{gen('mantra', stop='"')}", + "strength": {gen('strength', regex='[0-9]+', stop=',')}, + "items": ["{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}"] + }}""") + return lm + + grm = character_maker2(1, "A nimble fighter", ["axe", "sword", "bow"]) + check_grammar( + grm, + [ + '{‧\n‧ ‧ "‧name‧":', + ' "‧John‧ Do‧e‧"', + ',‧\n‧ ‧ "‧age‧":‧ ', + "3‧0‧,", + '\n‧ ‧ "‧arm‧or‧":‧ "', + "chain", + 'mail‧",‧\n‧ ‧ "‧we‧ap‧on‧":‧ "', + "s", + 'word‧",‧\n‧ ‧ "‧class‧":', + ' "‧war‧rior‧"', + ',‧\n‧ ‧ "‧m‧ant‧ra‧":', + ' "‧I‧ am‧ the‧ storm‧,‧ I‧ am‧ the‧ light‧ning‧,‧ I‧ am‧ the‧ th‧under‧."', + ',‧\n‧ ‧ "‧str‧ength‧":‧ ', + "1‧0‧0‧,", + '\n‧ ‧ "‧items‧":‧ ["', + 's‧word‧ of‧ light‧ning‧,‧ shield‧ of‧ th‧under‧,‧ hel‧met‧ of‧ storm‧."', + ",", + ' "‧s‧word‧ of‧ light‧ning‧,‧ shield‧ of‧ th‧under‧,‧ hel‧met‧ of‧ storm‧."', + ",", + ' "‧s‧word‧ of‧ light‧ning‧,‧ shield‧ of‧ th‧under‧,‧ hel‧met‧ of‧ storm‧."', + "]‧\n‧}", + ], + ) + + +if __name__ == "__main__": + test_llparser() diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 51c74729b..965c48ce2 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -1,3 +1,4 @@ +import pytest import guidance from guidance import gen, models @@ -21,7 +22,9 @@ def ble(lm): assert "{{G|" not in str(model + ble()) - +@pytest.mark.xfail( + reason="llguidance currently emits an additional empty capture group when no explicit stop is provided" +) def test_model_set(): model = models.Mock() model = model.set("num", "4") diff --git a/tests/unit/test_parser.py b/tests/unit/test_parser.py index 63fafbd93..10569a41f 100644 --- a/tests/unit/test_parser.py +++ b/tests/unit/test_parser.py @@ -1,97 +1,96 @@ from guidance import char_set, one_or_more, select, string, zero_or_more -from guidance._grammar import Byte, ByteRange -from guidance._parser import EarleyCommitParser +from guidance._parser import ByteParser def test_one_or_more(): g = one_or_more("a") - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"a")]) - parser.consume_byte(b"a") - assert parser.valid_next_bytes() == set([Byte(b"a")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"a"]) + parser.consume_bytes(b"a") + assert parser.valid_next_bytes() == set([b"a"]) def test_zero_or_more_and_one_or_more(): g = zero_or_more("a") + one_or_more("b") - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"a"), Byte(b"b")]) - parser.consume_byte(b"a") - assert parser.valid_next_bytes() == set([Byte(b"a"), Byte(b"b")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([Byte(b"b")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"a", b"b"]) + parser.consume_bytes(b"a") + assert parser.valid_next_bytes() == set([b"a", b"b"]) + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == set([b"b"]) - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"a"), Byte(b"b")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([Byte(b"b")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([Byte(b"b")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"a", b"b"]) + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == set([b"b"]) + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == set([b"b"]) def test_zero_or_more_and_one_or_more_mixed(): g = zero_or_more("a") + "test" + one_or_more("b") - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"a"), Byte(b"t")]) - parser.consume_byte(b"t") - parser.consume_byte(b"e") - parser.consume_byte(b"s") - assert parser.valid_next_bytes() == set([Byte(b"t")]) - parser.consume_byte(b"t") - assert parser.valid_next_bytes() == set([Byte(b"b")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"a", b"t"]) + parser.consume_bytes(b"t") + parser.consume_bytes(b"e") + parser.consume_bytes(b"s") + assert parser.valid_next_bytes() == set([b"t"]) + parser.consume_bytes(b"t") + assert parser.valid_next_bytes() == set([b"b"]) def test_select(): g = select(["bob", "bill", "sue"]) - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"b"), Byte(b"s")]) - parser.consume_byte(b"s") - assert parser.valid_next_bytes() == set([Byte(b"u")]) - parser.consume_byte(b"u") - assert parser.valid_next_bytes() == set([Byte(b"e")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"b", b"s"]) + parser.consume_bytes(b"s") + assert parser.valid_next_bytes() == set([b"u"]) + parser.consume_bytes(b"u") + assert parser.valid_next_bytes() == set([b"e"]) def test_select_nested(): g = select(["bob", "bill", select(["mark", "mary"])]) - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"b"), Byte(b"m")]) - parser.consume_byte(b"m") - assert parser.valid_next_bytes() == set([Byte(b"a")]) - parser.consume_byte(b"a") - assert parser.valid_next_bytes() == set([Byte(b"r")]) - parser.consume_byte(b"r") - assert parser.valid_next_bytes() == set([Byte(b"k"), Byte(b"y")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"b", b"m"]) + parser.consume_bytes(b"m") + assert parser.valid_next_bytes() == set([b"a"]) + parser.consume_bytes(b"a") + assert parser.valid_next_bytes() == set([b"r"]) + parser.consume_bytes(b"r") + assert parser.valid_next_bytes() == set([b"k", b"y"]) def test_select_joined(): g = select(["bob", "bill"]) + select(["mark", "mary"]) - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b"b")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([Byte(b"o"), Byte(b"i")]) - parser.consume_byte(b"i") - assert parser.valid_next_bytes() == set([Byte(b"l")]) - parser.consume_byte(b"l") - assert parser.valid_next_bytes() == set([Byte(b"l")]) - parser.consume_byte(b"l") - assert parser.valid_next_bytes() == set([Byte(b"m")]) - parser.consume_byte(b"m") - assert parser.valid_next_bytes() == set([Byte(b"a")]) - parser.consume_byte(b"a") - assert parser.valid_next_bytes() == set([Byte(b"r")]) - parser.consume_byte(b"r") - assert parser.valid_next_bytes() == set([Byte(b"k"), Byte(b"y")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b"b"]) + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == set([b"o", b"i"]) + parser.consume_bytes(b"i") + assert parser.valid_next_bytes() == set([b"l"]) + parser.consume_bytes(b"l") + assert parser.valid_next_bytes() == set([b"l"]) + parser.consume_bytes(b"l") + assert parser.valid_next_bytes() == set([b"m"]) + parser.consume_bytes(b"m") + assert parser.valid_next_bytes() == set([b"a"]) + parser.consume_bytes(b"a") + assert parser.valid_next_bytes() == set([b"r"]) + parser.consume_bytes(b"r") + assert parser.valid_next_bytes() == set([b"k", b"y"]) def test_char_set(): g = char_set("b-f") - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([ByteRange(b"bf")]) - parser.consume_byte(b"b") + parser = ByteParser(g) + assert parser.valid_next_bytes() == {bytes([i]) for i in range(ord("b"), ord("f") + 1)} + parser.consume_bytes(b"b") def test_byte_mask_char_set(): g = char_set("b-f") - parser = EarleyCommitParser(g) + parser = ByteParser(g) m = parser.next_byte_mask() for i in range(256): if ord(b"b") <= i <= ord(b"f"): @@ -102,7 +101,7 @@ def test_byte_mask_char_set(): def test_byte_mask_char_set2(): g = char_set("bf") - parser = EarleyCommitParser(g) + parser = ByteParser(g) m = parser.next_byte_mask() for i in range(256): if i == ord(b"b") or i == ord(b"f"): @@ -113,21 +112,21 @@ def test_byte_mask_char_set2(): def test_char_set_one_or_more(): g = one_or_more(char_set("b-f")) - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([ByteRange(b"bf")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([ByteRange(b"bf")]) - parser.consume_byte(b"b") - assert parser.valid_next_bytes() == set([ByteRange(b"bf")]) - parser.consume_byte(b"f") - assert parser.valid_next_bytes() == set([ByteRange(b"bf")]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == {bytes([i]) for i in range(ord("b"), ord("f") + 1)} + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == {bytes([i]) for i in range(ord("b"), ord("f") + 1)} + parser.consume_bytes(b"b") + assert parser.valid_next_bytes() == {bytes([i]) for i in range(ord("b"), ord("f") + 1)} + parser.consume_bytes(b"f") + assert parser.valid_next_bytes() == {bytes([i]) for i in range(ord("b"), ord("f") + 1)} def test_string_utf8(): b = bytes("¶", encoding="utf8") g = string("¶") - parser = EarleyCommitParser(g) - assert parser.valid_next_bytes() == set([Byte(b[:1])]) - parser.consume_byte(b[:1]) - assert parser.valid_next_bytes() == set([Byte(b[1:])]) - parser.consume_byte(b[1:]) + parser = ByteParser(g) + assert parser.valid_next_bytes() == set([b[:1]]) + parser.consume_bytes(b[:1]) + assert parser.valid_next_bytes() == set([b[1:]]) + parser.consume_bytes(b[1:]) diff --git a/tests/unit/test_protobuf.py b/tests/unit/test_protobuf.py deleted file mode 100644 index 74deea034..000000000 --- a/tests/unit/test_protobuf.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest -from itertools import chain - -from guidance import ( - byte_range, - char_range, - commit_point, - select, - string, - token_limit, - with_temperature, -) -from guidance._grammar import ( - Byte, - ByteRange, - GrammarFunction, - Join, - ModelVariable, - Select, -) - - -def compare_grammars(g1: GrammarFunction, g2: GrammarFunction) -> bool: - """Recursively compare two GrammarFunction objects for equivalence.""" - - if type(g1) != type(g2): - return False - - # Compare attributes based on type - if isinstance(g1, (Byte, ByteRange, ModelVariable)): - slots = chain.from_iterable(getattr(cls, '__slots__', []) for cls in type(g1).mro()) - return all(getattr(g1, slot) == getattr(g2, slot) for slot in slots) - elif isinstance(g1, (Join, Select)): - slots = chain.from_iterable(getattr(cls, '__slots__', []) for cls in type(g1).mro()) - return (all(getattr(g1, slot) == getattr(g2, slot) for slot in slots if 'values' not in slot) - and len(g1.values) == len(g2.values) # Check both have same number of child nodes - and all(compare_grammars(v1, v2) for v1, v2 in zip(g1.values, g2.values)) # Recursively compare child nodes - ) - else: - raise ValueError(f"Unsupported grammar type: {type(g1)}") - - -@pytest.mark.parametrize( - "grammar", - [ - string("Hello, world!"), - Byte(b"a"), - byte_range(b"\x00", b"\xff"), - char_range("a", "z"), - select(["option1", "option2", "option3"]), - commit_point(string("commit"), hidden=True), - token_limit(string("limited"), max_tokens=5), - with_temperature(string("temp"), temperature=0.5), - ModelVariable("my_variable"), - Join([string("part1"), string("part2")]), - select( - [ - string("option1"), - Join([string("part1"), string("part2")]), - ] - ), - ], -) -def test_grammar_protobuf_roundtrip(grammar: GrammarFunction): - """Test that grammars can be round-tripped through protobuf serialization.""" - serialized_grammar = grammar.serialize() - deserialized_grammar = GrammarFunction.deserialize(serialized_grammar) - - # Recursively compare the grammars - assert compare_grammars( - grammar, deserialized_grammar - ), f"Deserialized grammar does not match original:\nOriginal: {grammar}\nDeserialized: {deserialized_grammar}\n" - diff --git a/tests/utils.py b/tests/utils.py index 084597925..40fb49595 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,13 @@ import os -from typing import Any, Set, Union, Protocol +from typing import Set, Optional, Protocol import pytest from huggingface_hub import hf_hub_download import guidance from guidance import models -from guidance._grammar import Byte, ByteRange, GrammarFunction, Join -from guidance._parser import ParserException +from guidance._grammar import GrammarFunction, Join +from guidance._parser import ByteParserException opanai_model_cache = {} @@ -154,7 +154,7 @@ def check_match_failure( bad_string: str, good_bytes: bytes, failure_byte: bytes, - allowed_bytes: Union[Set[Union[Byte, ByteRange]], None], + allowed_bytes: Optional[Set[bytes]], grammar: GrammarFunction, ): """ @@ -164,14 +164,13 @@ def check_match_failure( allowed_bytes is allowed to be None, since it could be really complicated """ - with pytest.raises(ParserException) as pe: + with pytest.raises(ByteParserException) as pe: grammar.match(bad_string, raise_exceptions=True) - assert pe.value.consumed_bytes[:-1] == good_bytes + assert pe.value.consumed_bytes == good_bytes assert pe.value.current_byte == failure_byte if allowed_bytes is not None: assert pe.value.allowed_bytes == allowed_bytes - class GrammarFunctionCallable(Protocol): """ Protocol for a callable that returns a GrammarFunction and accepts @@ -185,12 +184,13 @@ def generate_and_check( grammar_callable: GrammarFunctionCallable, test_string: str, capture_key="my_capture", - stop_char: str = chr(7), + eos_token = "", ) -> models.Mock: # First, validate that the grammar actually accepts the test string grammar = grammar_callable(name=capture_key) match = grammar.match(test_string) - assert match.captures[capture_key].decode() == test_string + assert match is not None + assert match.captures[capture_key] == test_string # The next part is to prevent intermittent test failures # when the temperature is non-zero @@ -201,8 +201,8 @@ def generate_and_check( # with our round trip check. # So append a 'stop' character which we don't # use in any of our tests - assert stop_char not in test_string, f"stop_char {stop_char!r} in string" - prepared_string = f"{test_string}{stop_char}" + assert eos_token not in test_string, f"eos_token {eos_token!r} in string" + prepared_string = f"{eos_token}{test_string}{eos_token}" lm = models.Mock(prepared_string.encode()) # Run with the mock model