[WIP] Structured model state #929
base: main
@@ -11,7 +11,7 @@
 from pprint import pprint
-from typing import Dict, TYPE_CHECKING
+from typing import Dict, Optional, TYPE_CHECKING

 import numpy as np
@@ -38,6 +38,8 @@
         "Failed to load guidance.cpp, falling back to Python mirror implementations..."
     )
     from .. import _cpp as cpp
+
+from ._model_state import ModelState, Text, Object, RoleOpener, RoleCloser
 from ._guidance_engine_metrics import GuidanceEngineMetrics
 from .._utils import softmax, CaptureEvents
 from .._parser import EarleyCommitParser, Parser
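Since `_model_state.py` itself isn't part of this diff view, here is a reviewer-side sketch of the shape these names would need, inferred purely from how the hunks below use them (a list-like `ModelState` with `append`/`copy`/slicing plus `str()` and `_html()` serialization, and `Text`/`RoleOpener`/`RoleCloser` leaf objects). The actual module may well differ:

```python
# Hypothetical reconstruction -- inferred from call sites in this PR,
# not copied from the real _model_state.py.
import html
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Object:
    """One structured chunk of model state."""
    text: str = ""

    def __str__(self) -> str:
        return self.text

    def _html(self) -> str:
        return html.escape(self.text)


@dataclass
class Text(Object):
    probability: Optional[float] = None  # set only for generated tokens


@dataclass
class RoleOpener(Object):
    role_name: str = ""
    indent: bool = True


@dataclass
class RoleCloser(Object):
    role_name: str = ""
    indent: bool = True


class ModelState(List[Object]):
    """A list of Objects that can still serialize like the old string state."""

    def copy(self) -> "ModelState":
        return ModelState(self)

    def __getitem__(self, key):
        # Keep slices usable with str()/_html() by re-wrapping them.
        result = super().__getitem__(key)
        return ModelState(result) if isinstance(key, slice) else result

    def __str__(self) -> str:
        return "".join(str(obj) for obj in self)

    def _html(self) -> str:
        return "".join(obj._html() for obj in self)
```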
@@ -62,14 +64,6 @@
 # define some constants we will reuse many times
 _null_grammar = string("")
-format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
-nodisp_pattern = re.compile(
-    r"<\|\|_#NODISP_\|\|>.*?<\|\|_/NODISP_\|\|>", flags=re.DOTALL
-)
-html_pattern = re.compile(r"<\|\|_html:(.*?)_\|\|>", flags=re.DOTALL)
-image_pattern = re.compile(r"<\|_image:(.*?)\|>")


 class EngineCallResponse:
@@ -850,57 +844,19 @@ def __init__(self, engine, echo=True, **kwargs):
         self.echo = echo
         self.token_count = 0  # tracks how many tokens our byte state represents
         self.max_display_rate = 0.2  # this controls how frequently we are allowed to redraw the display (in seconds)
-        self.opened_blocks = {}  # what context blocks have been opened but not closed
+        self.opened_blocks: dict["ContextBlock", tuple[int, Optional[Object]]] = {}  # what context blocks have been opened but not closed
         # self.compute_log_probs = compute_log_probs

         # private attributes
         self._variables = {}  # these are the state variables stored with the model
         self._variables_log_probs = {}  # these are the state variables stored with the model
         self._cache_state = {}  # mutable caching state used to save computation
-        self._state = ""  # the current bytes that represent the state of the model
+        self._state = ModelState()  # the current bytes that represent the state of the model
         self._event_queue = None  # TODO: these are for streaming results in code, but that needs implemented
         self._event_parent = None
         self._last_display = 0  # used to track the last display call to enable throttling
         self._last_event_stream = 0  # used to track the last event streaming call to enable throttling

-    @property
-    def active_role_end(self):
-        """The default end patterns we should use for `gen` calls.
-        TODO: move this logic into the gen call...we can do with if we allow model_variables to run functions.
-
-        These patterns are computed dynamically by the model object because they can depend on
-        what the current open roles are, which is something
-        """
-
-        # add any active non-empty role ends. Ignore role ends that are spaces
-        parts = []
-        for _, role_end_str in self.opened_blocks.values():
-            role_end_str = format_pattern.sub("", role_end_str)
-            if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str):
-                parts.append(role_end_str)
-
-        return select(parts)
-
-    def _html(self):
-        """Generate HTML that displays the model object."""
-        display_out = self._state
-        for context in reversed(self.opened_blocks):
-            display_out += self.opened_blocks[context][1]
-        display_out = html.escape(display_out)
-        display_out = nodisp_pattern.sub("", display_out)
-        display_out = html_pattern.sub(lambda x: html.unescape(x.group(1)), display_out)
-        display_out = image_pattern.sub(
-            lambda x: '<img src="data:image/png;base64,'
-            + base64.b64encode(self[x.groups(1)[0]]).decode()
-            + '" style="max-width: 400px; vertical-align: middle; margin: 4px;">',
-            display_out,
-        )
-        display_out = (
-            "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"
-            + display_out
-            + "</pre>"
-        )
-        return display_out
-
     def _send_to_event_queue(self, value):
         """For streaming in code.
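One way to read the `__init__` changes above: the model's transcript is no longer a flat string with `<||_..._||>` display tags spliced in, but a sequence of typed chunks. A small illustration under the sketched classes from earlier (the role-tag strings are made up; real ones come from the chat template):

```python
state = ModelState()
state.append(RoleOpener(role_name="user", text="<|user|>\n"))
state.append(Text(text="hello", probability=0.92))  # generated text keeps its probability
state.append(RoleCloser(role_name="user", text="<|end|>\n"))

assert str(state) == "<|user|>\nhello<|end|>\n"  # prompt serialization
rendered = state._html()  # display markup now comes from the objects, not regexes
```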
@@ -924,6 +880,7 @@ def copy(self):
         new_lm._variables = self._variables.copy()
         new_lm._variables_log_probs = self._variables_log_probs.copy()
         new_lm.opened_blocks = self.opened_blocks.copy()
+        new_lm._state = self._state.copy()

         # create a new clean event queue
         new_lm._event_queue = None  # we start with no event queue because nobody is listening to us yet

@@ -938,7 +895,7 @@ def copy(self):

         return new_lm

-    def _inplace_append(self, value, force_silent=False):
+    def _inplace_append(self, obj: Object, force_silent: bool = False):
         """This is the base way to add content to the current LM object that is being constructed.

         All updates to the model state should eventually use this function.

@@ -951,7 +908,7 @@ def _inplace_append(self, value, force_silent=False):
         """

         # update the byte state
-        self._state += str(value)  # TODO: make _state to be bytes not a string
+        self._state.append(obj)

         # see if we should update the display
         if not force_silent:
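The signature change above means `_inplace_append` no longer coerces its argument with `str()`, so every call site must wrap plain strings itself. The before/after shape of a call site, for skimming:

```python
lm._inplace_append("some text")             # old: anything, coerced via str(value)
lm._inplace_append(Text(text="some text"))  # new: an explicit typed chunk
```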
@@ -995,20 +952,30 @@ def reset(self, clear_variables=True):
             self._variables_log_probs = {}
         return self

+    def _html(self):
+        out = self._state._html()
+        for context in reversed(self.opened_blocks):
+            _, closer = self.opened_blocks[context]
+            if closer is not None:
+                out += closer._html()
+        return out
+
     def _repr_html_(self):
         if ipython_is_imported:
             clear_output(wait=True)
         return self._html()

-    def _current_prompt(self):
+    def _current_prompt(self) -> str:
         """The current prompt in bytes (which is the state without the context close tags)."""
-        return format_pattern.sub("", self._state)
+        return str(self._state)

     def __str__(self):
         """A string representation of the current model object (that includes context closers)."""
-        out = self._current_prompt()
+        out = str(self._state)
         for context in reversed(self.opened_blocks):
-            out += format_pattern.sub("", self.opened_blocks[context][1])
+            _, closer = self.opened_blocks[context]
+            if closer is not None:
+                out += str(closer)
         return out

     def __add__(self, value):
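The `_current_prompt` / `__str__` split above is easiest to see with a still-open block; continuing with the sketched classes and made-up role tags:

```python
state = ModelState([RoleOpener(role_name="user", text="<|user|>\n"),
                    Text(text="hello")])
pending = RoleCloser(role_name="user", text="<|end|>\n")  # held in opened_blocks

prompt = str(state)               # "<|user|>\nhello"           (_current_prompt)
full = str(state) + str(pending)  # "<|user|>\nhello<|end|>\n"  (__str__)
```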
@@ -1019,6 +986,8 @@ def __add__(self, value):
         value : guidance grammar
             The grammar used to extend the current model.
         """
+        # Import in function to guard against circular import
+        from ..library._role import RoleBlock
Comment on lines +989 to +990

I have a feeling that some may object to the import inside of a function... Is anyone else better at resolving issues with circular imports?

Well, we haven't tried turning on …

         # create the new lm object we will return
         # (we need to do this since Model objects are immutable)
@@ -1034,7 +1003,7 @@ def __add__(self, value):
                 new_blocks.append(context)

                 # mark this so we don't re-add when computing the opener or closer (even though we don't know the close text yet)
-                lm.opened_blocks[context] = (0, "")
+                lm.opened_blocks[context] = (0, None)

         # find what old blocks need to be removed
         old_blocks = []
@@ -1046,20 +1015,44 @@ def __add__(self, value):
                 del lm.opened_blocks[context]

         # close any newly closed contexts
-        for (pos, close_text), context in old_blocks:
+        for (pos, closer), context in old_blocks:
+            assert closer is not None
             if context.name is not None:
-                lm._variables[context.name] = format_pattern.sub(
-                    "", lm._state[pos:]
-                )
-            lm += context.closer
+                # Capture
+                lm._variables[context.name] = str(lm._state[pos:])
+            lm._inplace_append(closer)

         # apply any newly opened contexts (new from this object's perspective)
         for context in new_blocks:
-            lm += context.opener
+            if isinstance(context, RoleBlock):
+                # Apply the opener (a grammar)
+                with grammar_only():
+                    # TODO: be careful about the temp lm's display? (e.g. with silent())
+                    tmp = lm + context.opener
+                open_text = str(tmp._state[len(lm._state):])  # get the new state added by calling the opener
+                # Add that new state as text in a RoleOpener
+                lm._inplace_append(
+                    RoleOpener(
+                        role_name=context.role_name,
+                        text=open_text,
+                        indent=getattr(lm, "indent_roles", True)
+                    )
+                )
+            else:
+                lm += context.opener
             with grammar_only():
                 # TODO: be careful about the temp lm's display? (e.g. with silent())
                 tmp = lm + context.closer
-                close_text = tmp._state[len(lm._state):]  # get the new state added by calling the closer
-                lm.opened_blocks[context] = (len(lm._state), close_text)
+                close_text = str(tmp._state[len(lm._state):])  # get the new state added by calling the closer
+                if isinstance(context, RoleBlock):
+                    closer = RoleCloser(
+                        role_name=context.role_name,
+                        text=close_text,
+                        indent=getattr(lm, "indent_roles", True)
+                    )
+                else:
+                    closer = Text(text=close_text)
+                lm.opened_blocks[context] = (len(lm._state), closer)

         # clear out names that we override
         if context.name is not None:
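To make the open/close rewrite above concrete, here is roughly what a role block now leaves in the structured state. The trace below is an assumed usage sketch (`models.Mock` and the `user()` block exist in guidance; the exact tag text depends on the model's chat template):

```python
from guidance import models, user

lm = models.Mock()
with user():
    lm += "Hello!"

# The structured state now holds typed objects instead of tagged strings,
# roughly:
#   RoleOpener(role_name="user", text="<|user|>\n", indent=True)
#   Text(text="Hello!")
#   RoleCloser(role_name="user", text="<|end|>\n", indent=True)
```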
@@ -1075,7 +1068,9 @@ def __add__(self, value):

             # we have no embedded objects
             if len(parts) == 1:
-                lm._inplace_append(value)
+                lm._inplace_append(
+                    Text(text=value)
+                )
                 out = lm

             # if we have embedded objects we have to convert the string to a grammar tree
@@ -1119,7 +1114,8 @@ def __add__(self, value):
             )

         # this flushes the display
-        out._inplace_append("")
+        # TODO: directly call _update_display?
+        out._inplace_append(Text(text=""))

         return out
@@ -1147,9 +1143,8 @@ def __getitem__(self, key):
         else:
             for context in list(reversed(self.opened_blocks)):
                 if context.name == key:
-                    return format_pattern.sub(
-                        "", self._state[self.opened_blocks[context][0] :]
-                    )
+                    pos, _ = self.opened_blocks[context]
+                    return str(self._state[pos:])

         raise KeyError(f"Model does not contain the variable '{key}'")
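A subtlety in the capture path above: with a flat string, the stored position was a character offset, but `len(lm._state)` on a `ModelState` is an object index, so `self._state[pos:]` slices whole objects. Under the earlier sketch:

```python
state = ModelState([Text(text="abc"), Text(text="def")])
pos = len(state)                  # 2 -- an object index, not a character offset
state.append(Text(text="ghi"))
assert str(state[pos:]) == "ghi"  # everything appended after pos
```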
@@ -1328,11 +1323,19 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                     if len(chunk.new_bytes) > 0:
                         generated_value += new_text
                         if chunk.is_generated:
-                            lm += f"<||_html:<span style='background-color: rgba({165*(1-chunk.new_bytes_prob) + 0}, {165*chunk.new_bytes_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{chunk.new_bytes_prob}'>_||>"
-                        lm += new_text
-                        if chunk.is_generated:
-                            lm += "<||_html:</span>_||>"
+                            self._inplace_append(
+                                Text(
+                                    text=new_text,
+                                    # TODO: this will be slightly wrong if we have a delayed byte string
+                                    probability=chunk.new_bytes_prob
+                                )
+                            )
+                        else:
+                            self._inplace_append(
+                                Text(
+                                    text=new_text,
+                                )
+                            )
                     # last_is_generated = chunk.is_generated

                     if len(chunk.capture_groups) > 0:
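The inline `<span>` shading deleted above presumably moves into the HTML renderer for `Text`; a hedged sketch of how that renderer could reproduce the old red-to-green probability shading (the method body is assumed; only the color formula is taken from the deleted line):

```python
import html

def render_text(t: Text) -> str:
    # Hypothetical stand-in for Text._html(); not part of this diff.
    escaped = html.escape(t.text)
    if t.probability is None:
        return escaped
    # Same rgba() formula the old inline markup used.
    red = 165 * (1 - t.probability)
    green = 165 * t.probability
    return (
        f"<span style='background-color: rgba({red}, {green}, 0, 0.15); "
        f"border-radius: 3px;' title='{t.probability}'>{escaped}</span>"
    )
```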
The current implementation of active_role_end will actually produce the closing tag twice, since it is essentially a stateless grammar and can't pop ContextBlocks off of a model. Furthermore, generating active_role_end requires the model to produce special tags which will be hidden from future prompts... A better implementation of this may exist, but I'm not quite sure yet. Deleting it for now.