[WIP] Replace EarleyParser with lexeme-based rust implementation (#951)
To preface this all, I want to note that the overall user experience of
`guidance` should change minimally if at all.
This is primarily a behind-the-scenes change that offers
1. Speedups (sometimes 6x, at least on my machine!) over the current
implementation of `guidance`
1. Some new pieces of plumbing (not yet a part of the public API) that
aim to simplify development of programming language (PL) grammars
1. Simplification of the core LLM-parser loop that makes our codebase a
little less scary

# `llguidance`
- [llguidance](https://github.com/microsoft/llguidance) is a library
designed to handle the "low level" components of `guidance` (mainly the
implementation of the parser and its interactions with the tokenizer)
- These components have been pulled out of `guidance` so that they can
be reused in the server-side implementation of the `AzureGuidance`
endpoints (ensuring that behavior between these remote models and local
`guidance`-controlled models is as consistent as possible)

# `guidance`

## Parser
- Removed `EarleyParser` class that operated at the **byte** level
    - Byte-by-byte consumption (usage: search the token trie, using the parser to validate acceptable tokens)
    - Parse-tree built over individual bytes
- Introduced `TokenParser` class that operates at the **token** level
    - Token-by-token consumption (usage: directly compute the token mask)
    - Parse-tree built over **lexemes** (more on these later, but larger chunks of work than bytes)
    - Still an Earley parser under the hood
    - Light wrapper around the rust-based `LLInterpreter` class from `llguidance`
- Introduced `ByteParser` class, which wraps a `TokenParser` and a `ByteTokenizer` (a tokenizer whose tokens correspond directly to individual bytes, plus a `BOS`/`EOS` token)
    - Only used to replicate the old byte-by-byte consumption and to give grammars a `match` method (see the sketch after this list)
    - These features are only used in the test suite (changes to the tests were kept as minimal as possible)
    - Hope to deprecate this class in the future (the important bits could maybe be subsumed into the `Mock` model class?)
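
For illustration, a minimal sketch of how `match` gets used (the exact return type of `match` is an assumption here):

```python
# Minimal sketch of the test-only `match` usage described above.
# Assumption: `match` returns a match-like object on success and None on failure.
from guidance import gen

grammar = "id: " + gen(regex=r"[0-9]+")

assert grammar.match("id: 42") is not None   # parsed byte-by-byte via the ByteParser
assert grammar.match("id: abc") is None      # at least one digit is required, so this fails
```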

## Engine
- Removed token-trie search from `Engine` class in favor of directly
using the mask from the `TokenParser`
- A lot of code got deleted, hopefully making this implementation a lot
less scary to newcomers (even if the scary bits just got pushed down
into rust...)
- Current guidance allows grammars that are in an "accepting state" to
sample a token *without constraints*. If the sampled token is not
accepted by the grammar, then it will be treated as an EOS token and
terminate the grammar (@Harsha-Nori @slundberg maybe one of you can
confirm that I have this right?).
This PR keeps this behavior, but note **THIS LIKELY DOES NOT ALIGN WITH
AZURE SERVER-SIDE IMPLEMENTATION**. To ensure that this behavior is
maintained (maybe worth a discussion), it should be moved into
llguidance (@mmoskal)
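
At a high level, the new loop looks something like the sketch below (the method names on `parser` and `engine` are illustrative assumptions, not the actual API):

```python
# Illustrative sketch of the simplified LLM-parser loop; names are assumptions.
def run(engine, parser, tokens):
    while not parser.done():
        mask, temperature = parser.advance()       # token mask comes straight from llguidance
        logits = engine.get_logits(tokens)
        if parser.is_accepting():
            # The grammar could stop here: sample WITHOUT the mask; if the sampled
            # token is rejected by the grammar, treat it as an EOS and terminate.
            token = engine.sample(logits, temperature)
            if not mask[token]:
                break
        else:
            token = engine.sample(logits, temperature, mask=mask)
        tokens.append(token)
        parser.consume_token(token)                # token-by-token; no trie search
    return tokens
```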

## Serialization
- The `LLInterpreter` underlying the `TokenParser` expects
JSON-serialized grammars
- This serialization format is consistent with the format expected by
remote `AzureGuidance` endpoints
- The `RemoteEngine` now expects this serialization format (no more
protobuf)
- `LLInterpreter` returns JSON-serialized data (not exactly the same as
what's returned by `AzureGuidance` endpoints, but there is a lot of
shared structure)
- New file `_schema.py` contains `pydantic` schemas used to
validate/parse these response structures
- `EngineCallResponse` now JSON serialized/validated with `pydantic` (no
more protobuf)
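
As an example, the new schemas in `_schema.py` (shown further down in this diff) can parse a response like the hand-written payload below (illustrative values, not captured from a real run):

```python
# Parse an LLInterpreter-style JSON response with the new pydantic schemas.
from guidance._schema import LLInterpreterResponse

raw = """
{
  "progress": [
    {"object": "text", "hex": "68656c6c6f", "num_tokens": 2,
     "log_prob": -0.7, "is_generated": true}
  ],
  "stop": false,
  "temperature": 0.0
}
"""

resp = LLInterpreterResponse.model_validate_json(raw)
engine_response = resp.progress.to_engine_call_response()
assert engine_response.new_bytes == b"hello"   # "68656c6c6f" is "hello" in hex
assert engine_response.new_token_count == 2
```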

## New primitives:
- `Gen`
- `Lexeme`
- `Subgrammar`
- `RegularGrammar`

To understand the new primitives, we need to understand how the new
parser is different from the old one. While both are Earley parsers that
support general context-free grammars, the smallest "atoms" that the new
parser works with are more coarse-grained than the old one. The new
parser works with **lexemes**, while the old one works with **bytes**.

Roughly, lexemes correspond to *regular expressions* (and string
literals). These are larger chunks of text, making the parser more
efficient in a lot of cases. Because lexemes are regular, the lexer (the
lexeme sub-parser) can run much more quickly than the outer Earley
parser.

While the Earley parser is able to handle *ambiguities* (e.g.
`one_or_more("a") + one_or_more(select("a", "b"))` -- which expression
is responsible for the second "a" in "aab"?), the lexer can't. We need a
deterministic set of rules that tells us how any given string should be
lexed (what lexemes are responsible for what parts of the text).

Lexemes can be **lazy** or **greedy**.
- Lazy:
    - Completes as early as possible
    - A lazy lexeme will complete as soon as it matches, e.g. a lazy lexeme `r"a+"` will only ever generate a single "a" and `r"a*"` will only ever generate the empty string.
- Greedy:
    - Completes as late as possible
    - A greedy lexeme will not complete until it *fails* to match, e.g. a greedy lexeme `r"a+"` can produce as many "a"s as it wants before moving on to the next lexeme.
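
The distinction is analogous to greedy vs. non-greedy quantifiers in an ordinary regex engine (analogy only; lexeme laziness is about when the lexer decides a lexeme is complete, not about Python regex syntax):

```python
# Analogy only: greedy matches complete as late as possible, lazy ones as early as possible.
import re

text = "aaab"
print(re.match(r"a+", text).group())    # "aaa" -- greedy keeps consuming while it still matches
print(re.match(r"a+?", text).group())   # "a"   -- lazy completes as soon as it matches
print(re.match(r"a*?", text).group())   # ""    -- lazy a* matches the empty string immediately
```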

### Gen
`gen` is composed of two sub-expressions: the "body" regex and the
(optional) "stop" regex. If no body regex is passed, the body defaults to
`r"(?s:.*)"`, i.e. `.*` where `.` also matches newline characters.

When the stop regex is provided, `gen` behaves as a **lazy** lexeme. As
soon as the full body+stop regex matches the generated text, we exit the
`gen` (discarding the "stop" text) and move on to the next lexeme. This
ensures that gen actually stops when the stop expression is produced.

When no `stop` regex is provided, it behaves as a **greedy** lexeme
(with one caveat -- it can be terminated by an `EOS` token, which stops
with lazy semantics). Note that `regex` is now just an alias for `gen`
with no stop expression.

Examples:
- `gen(regex=r"[0-9]+") + "xyz"`
    - No stop provided, so the gen is greedy.
    - Must produce at least one digit, after which it is allowed to either produce more digits, the EOS token, or the letter "x".
    - If it produces "x", then the string "yz" will be forced.
    - If it produces the EOS token (or hits max_tokens), then the string "xyz" will be forced (the resulting string will NOT have the EOS -- it will be dropped).
- `gen() + "xyz"`
    - No stop provided, so the gen is greedy.
    - Implicitly allowed to generate anything.
    - Generating "x" will NOT force "yz", since generating "x" does not terminate the gen (`r"(?s:.*)"` has not yet failed to match). Note that current guidance ALSO won't force "yz", as the parser will be in a "superposition" that doesn't know whether or not the `gen` has completed.
    - Only an EOS (or hitting max_tokens) can terminate the gen.
    - **GOTCHA** (difference from current guidance):
        - Current guidance: any string ending in "xyz" is *allowed* to terminate the grammar (i.e. an EOS is allowed but not forced). In practice, this probably won't happen often.
        - This PR: the gen must first be terminated, at which point an "xyz" will be forced and the overall grammar will terminate.
- `gen(regex=r"[0-9]+") + "123"`
    - Greedy.
    - Generating "1" will not force "23"...
    - (Same "gotcha" as above): the only way to complete this expression is for the model to generate at least one digit and then the EOS (or hit the token limit), at which point "123" will be forced.
- `gen(regex=r"[0-9]+", stop="123")`
    - Lazy.
    - Will produce any number of digits, terminating as soon as "123" is generated (e.g. "2346452341412134123").

The subtle changes around EOS handling *should* be a fairly small detail in practice.
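
For reference, a hedged usage sketch of the two behaviors (the model choice and capture name are arbitrary examples):

```python
# Illustrative usage of lazy (stop) vs. greedy gen; the model choice is arbitrary.
from guidance import models, gen

lm = models.Transformers("gpt2")

# Lazy: terminates as soon as a newline is generated; the newline (the "stop" text) is discarded.
lm += "The answer is " + gen(name="answer", regex=r"[0-9]+", stop="\n")

# Greedy: no stop regex, so only EOS (or max_tokens) ends the gen, after which "xyz" is forced.
lm += gen(regex=r"[0-9]+", max_tokens=10) + "xyz"
```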

### Lexeme
Not (yet) part of "public" api (available via `guidance._grammar.Lexeme`
or `guidance.library._subgrammar.lexeme`). Should only really be used
when writing `Subgrammar`s / translating EBNF.

- Consist of a single regular expression
- They are always greedy.
- The model is not allowed to use an EOS as an "out" (unless we are at the end
of the grammar, of course)
- TODO: `lexeme` to support a `contextual` flag (more on that later)

### Subgrammar
Not (yet) part of "public" api (available via
`guidance._grammar.Subgrammar` or
`guidance.library._subgrammar.subgrammar`). Mostly exists to better
support generating programming languages.

- Wraps a guidance grammar, which will then be treated as
"atomic"/"terminal" from the perspective of the outer grammar's Earley
parser (i.e. treated as a greedy lexeme).
- Can be terminated by an EOS or by generating non-matching string (e.g.
if `json` is a subgrammar, `json() + "```"` will terminate if a backtick
is generated after some valid JSON)
- `"ignore_regex"` kwarg specifies a regular expression that will be
"ignored" between lexemes. This can be used to allow flexible whitespace
when generating JSON or code, for example.
- Non-contextual lexemes are given priority whenever any lexeme is being
generated (e.g. to support keywords in PLs)
- Note: `json` has been reimplemented as a `Subgrammar`.
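
A sketch of how these pieces might fit together (the exact `lexeme`/`subgrammar` signatures, other than the `ignore_regex` kwarg, are assumptions):

```python
# Hypothetical sketch: a tiny arithmetic subgrammar built from lexemes.
from guidance import zero_or_more
from guidance.library._subgrammar import lexeme, subgrammar

number = lexeme(r"[0-9]+")   # greedy lexeme: no EOS "out" mid-grammar
op = lexeme(r"[+*-]")
expr = number + zero_or_more(op + number)

# The outer grammar's Earley parser treats the whole subgrammar as a single greedy
# lexeme; ignore_regex allows flexible whitespace between the inner lexemes.
arith = subgrammar(body=expr, ignore_regex=r"[ \t]*")

grammar = "result: " + arith + "\n"
```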

### RegularGrammar
Not (yet) part of "public" api (available via
`guidance._grammar.RegularGrammar` or
`guidance.library._grammar.as_regular_grammar`)

**NOTE**: "manually" building regex-esque grammars is now discouraged
- e.g. `select(["0", char_range("1", "9") + zero_or_more(char_range("0",
"9"))])` should be rewritten as `regex(r"0|(?:[1-9][0-9]*)")`

This is because the lexemes here are individual characters, requiring
the expensive Earley parser to run. Rewriting as a `regex` makes the
entire grammar into a lexeme, allowing the cheap lexer to do all the
work.

If directly writing the regex is not possible, `as_regular_grammar`
(name subject to change) can *wrap* a grammar like `select(["0",
char_range("1", "9") + zero_or_more(char_range("0", "9"))])` and (try to)
convert it into a regex lexeme. Grammars that are not regular will fail
this construction.
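
For example (assuming `regex`, `select`, `char_range`, and `zero_or_more` are importable from the top-level `guidance` namespace):

```python
from guidance import char_range, regex, select, zero_or_more
from guidance.library._grammar import as_regular_grammar

# Discouraged: single-character lexemes force the (expensive) Earley parser to do all the work.
slow = select(["0", char_range("1", "9") + zero_or_more(char_range("0", "9"))])

# Preferred: the whole expression becomes one regex lexeme handled by the cheap lexer.
fast = regex(r"0|(?:[1-9][0-9]*)")

# If writing the regex by hand is impractical, wrap the grammar instead;
# construction fails if the wrapped grammar is not regular.
wrapped = as_regular_grammar(slow)
```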

In the future, it would be nice to automatically wrap grammars when we
can, preventing users from having to think about this.

## Deprecations:
- `commit_point` (raises `NotImplementedError`; may reimplement in the future?)
    - Was only used in `gen` to support the current "stop" mechanics and in tool calling
    - `gen` does not currently support tool calling (working on this; hopefully will have something working before this PR goes through)

## Biggest gotchas / changes from current `guidance`
- `regex(r"\d*") + "7"`
    - Current guidance is allowed to emit an EOS after any sequence of digits ending in "7"
    - Under this PR, guidance is allowed to emit an EOS after any sequence of digits, at which point a "7" will be forced
- `r"\d"` now matches unicode digits

## TODOs
- server-side engine variables
- EOS can't be explicitly referenced in the grammar, only implicitly at
the end of `Gen`s or `Subgrammar`s
- stopping gen on active role end
    - Should be more trivial with server-side engine variables

---------

Co-authored-by: Michal Moskal <[email protected]>
hudson-ai and mmoskal authored Aug 20, 2024
1 parent 7da7df7 commit 17823bf
Showing 48 changed files with 2,864 additions and 3,090 deletions.
715 changes: 456 additions & 259 deletions guidance/_grammar.py


893 changes: 259 additions & 634 deletions guidance/_parser.py


120 changes: 120 additions & 0 deletions guidance/_schema.py (new file)

```python
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field, NonNegativeInt, RootModel, model_validator, computed_field
from typing_extensions import Annotated
from functools import cached_property


class GuidanceEngineMetrics(BaseModel):
    engine_input_tokens: NonNegativeInt = 0
    engine_output_tokens: NonNegativeInt = 0


class EngineCallResponse(BaseModel):
    new_bytes: bytes
    is_generated: bool
    new_bytes_prob: float
    capture_groups: dict
    capture_group_log_probs: dict
    new_token_count: NonNegativeInt


class GenData(BaseModel):
    tokens: list[int]
    mask: bytes
    temperature: float
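    # Nonzero bytes in `mask` mark token ids that are allowed next (see valid_next_tokens).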

    @computed_field  # type: ignore[misc]
    @cached_property
    def valid_next_tokens(self) -> list[int]:
        return [i for i, b in enumerate(self.mask) if b != 0]


class LLProgressCapture(BaseModel):
    object: Literal["capture"]
    name: str
    hex: str
    log_prob: float
    list_append: bool = False

    @model_validator(mode="before")
    def strip_list_append_prefix(cls, values):
        name = values["name"]
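        # "__LIST_APPEND:" is 14 characters; strip the prefix and flag the capture as list-append.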
        if name.startswith("__LIST_APPEND:"):
            values["name"] = name[14:]
            # Override whatever was set
            values["list_append"] = True
        return values


class LLProgressText(BaseModel):
    object: Literal["text"]
    hex: str
    num_tokens: NonNegativeInt
    log_prob: float
    is_generated: bool


class LLProgressFinalText(BaseModel):
    object: Literal["final_text"]
    # we don't need to handle this for now


LLProgressItem = Annotated[
    Union[LLProgressCapture, LLProgressText, LLProgressFinalText],
    Field(discriminator="object"),
]


class LLProgress(RootModel):
    root: list[LLProgressItem]

    def to_engine_call_response(self) -> EngineCallResponse:
        new_bytes = b""
        new_token_count = 0
        new_bytes_prob = 0.0
        is_generated = False
        capture_groups: dict[str, Any] = {}
        capture_group_log_probs: dict[str, Any] = {}
        num_text_entries = 0

        for j in self.root:
            if isinstance(j, LLProgressCapture):
                is_generated = True
                cname = j.name
                data = bytes.fromhex(j.hex)
                if j.list_append:
                    if cname not in capture_groups or not isinstance(
                        capture_groups[cname], list
                    ):
                        capture_groups[cname] = []
                        capture_group_log_probs[cname] = []
                    capture_groups[cname].append(data)
                    capture_group_log_probs[cname].append(j.log_prob)
                else:
                    capture_groups[cname] = data
                    capture_group_log_probs[cname] = j.log_prob
            elif isinstance(j, LLProgressText):
                # it actually should only happen once per round...
                new_bytes += bytes.fromhex(j.hex)
                new_token_count += j.num_tokens
                new_bytes_prob += j.log_prob
                is_generated |= j.is_generated
                num_text_entries += 1
        if num_text_entries > 0:
            new_bytes_prob /= num_text_entries

        return EngineCallResponse(
            new_bytes=new_bytes,
            new_token_count=new_token_count,
            new_bytes_prob=new_bytes_prob,
            is_generated=is_generated,
            capture_groups=capture_groups,
            capture_group_log_probs=capture_group_log_probs,
        )


class LLInterpreterResponse(BaseModel):
    progress: LLProgress
    stop: bool
    temperature: Optional[float]
```
101 changes: 0 additions & 101 deletions guidance/_serialization.proto (file deleted)

54 changes: 0 additions & 54 deletions guidance/_serialization_pb2.py (file deleted)
