[Feature] Allow @guidance to decorate methods #1035

Status: Open. Wants to merge 56 commits into base: main. Changes shown below are from all commits.
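
For context, the change lets @guidance decorate instance methods in addition to plain functions, with `self` bound as usual. A minimal usage sketch under that assumption (the class, options, and prompt below are illustrative, not taken from the PR):

    import guidance
    from guidance import select

    class Chooser:
        def __init__(self, options):
            self.options = options

        @guidance(stateless=True)
        def pick(self, lm, question):
            # `self` is bound through the new GuidanceMethod machinery;
            # `lm` is threaded through exactly as for a plain @guidance function
            return lm + question + " Answer: " + select(self.options)

    # usage (assuming `model` is any guidance model instance):
    #   chooser = Chooser(["yes", "no"])
    #   lm = model + chooser.pick("Is the sky blue?")
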
Commits (56):

0b9b1f5  make a guidance_metho decorator (hudson-ai, Sep 21, 2024)
943a801  handle partial one level up (hudson-ai, Sep 21, 2024)
1e67aaf  dedent up the stack so we can do it for methods too (hudson-ai, Sep 21, 2024)
19882a2  cache up the stack so we capture the self arg in guidance_method (hudson-ai, Sep 21, 2024)
eaa875f  add failing guidance_method tests (hudson-ai, Sep 23, 2024)
657dbc0  drop cache arg (hudson-ai, Sep 23, 2024)
3edf627  export guidance_method (hudson-ai, Sep 23, 2024)
00b13bd  move partial to where caching and dedenting happens... (hudson-ai, Sep 23, 2024)
14e478b  move reference to closure since MethodType has slots (hudson-ai, Sep 23, 2024)
c66c1c0  remove guidance_method by making consolidated GuidanceDecorator class (hudson-ai, Sep 23, 2024)
f969a50  add separate GuidanceMethod subclass (hudson-ai, Sep 23, 2024)
c4a5ee9  cache wrapper so we can do recursion (hudson-ai, Sep 23, 2024)
364c49f  wraps (hudson-ai, Sep 23, 2024)
866f1c0  modify signature of GuidanceFunction, not wrapped function (hudson-ai, Sep 23, 2024)
fa88413  Make method caching tests more extensive (hudson-ai, Sep 23, 2024)
67d541e  Test that dedenting works with methods (hudson-ai, Sep 23, 2024)
0be7ea5  Update comment (hudson-ai, Sep 23, 2024)
41b1be5  reprs (hudson-ai, Sep 23, 2024)
91ff5cf  copy more metadata with strip_multiline_string_indents (hudson-ai, Sep 23, 2024)
90fb9ef  add explicit tests for recursion issues (failing for method) (hudson-ai, Sep 23, 2024)
40417f2  fix method recursion (hudson-ai, Sep 23, 2024)
e8afc63  Merge branch 'main' into guidance_method (hudson-ai, Sep 24, 2024)
556b4da  move cache down into __init__ so we cache either functions or bound m… (hudson-ai, Sep 24, 2024)
f48e8df  Use a weak cache for GuidanceFunction.__get__ (hudson-ai, Sep 24, 2024)
c4fba12  Add tests for GuidanceMethod garbage collection (hudson-ai, Sep 24, 2024)
0a42cc9  Move cache back up (hudson-ai, Sep 24, 2024)
ffa354d  Clunky fix to get gc working (hudson-ai, Sep 24, 2024)
7379191  Clean up a bit (hudson-ai, Sep 24, 2024)
98f80b4  Add tests to make sure we got the signature right (hudson-ai, Sep 24, 2024)
acc939c  Cleaner (hudson-ai, Sep 24, 2024)
1fd4595  signature_pop helper util (hudson-ai, Sep 24, 2024)
1296183  improve signature tests (hudson-ai, Sep 24, 2024)
2d0f98d  drop unused owner (hudson-ai, Sep 24, 2024)
e26cb76  additional cache miss test (hudson-ai, Sep 24, 2024)
bafa37f  add failing test (good, I was confused about why this worked) (hudson-ai, Sep 25, 2024)
71aa3d8  Move cache back down but be more careful this time (hudson-ai, Sep 25, 2024)
64e95e1  Make sure to invalidate the cache when the instance's hash changes (hudson-ai, Sep 25, 2024)
480479e  Something terrifying is going on (hudson-ai, Sep 25, 2024)
5c5dcbf  Only the return was needed. Still terrifying. (hudson-ai, Sep 25, 2024)
7410338  Resolve existential terror (hudson-ai, Sep 25, 2024)
0a06899  Use WeakMethod (hudson-ai, Sep 25, 2024)
486cbe3  Fix comment (hudson-ai, Sep 25, 2024)
308a7c8  make sure our _methods dict acts like a WeakKeyDictionary (hudson-ai, Sep 25, 2024)
42078b1  Remove extra test case (found and fixed why I couldn't replicate issu… (hudson-ai, Sep 25, 2024)
307a371  xfail because reference cycles are hard (hudson-ai, Sep 25, 2024)
d294cfa  Solve so many problems by taking heavy inspiration from WeakMethod (hudson-ai, Sep 25, 2024)
05e259a  Move impl cache to GuidanceMethod class (hudson-ai, Sep 25, 2024)
e2bcd2c  Fix repr (hudson-ai, Sep 25, 2024)
fd9ff5f  comment (hudson-ai, Sep 25, 2024)
1fcf4a5  fix finalizer (hudson-ai, Sep 25, 2024)
bf132eb  more informative error in make_weak_bound_method; no need for weak re… (hudson-ai, Sep 25, 2024)
ddd7fac  comment (hudson-ai, Sep 25, 2024)
769bca0  words (hudson-ai, Sep 26, 2024)
ef2d7c2  fix comment (hudson-ai, Sep 27, 2024)
07d58c1  comment (hudson-ai, Sep 27, 2024)
612a810  Merge branch 'main' into guidance_method (hudson-ai, Oct 7, 2024)
guidance/_guidance.py (202 changes: 137 additions & 65 deletions). The resulting code after the change:

@@ -1,9 +1,11 @@
import functools
import inspect
import threading
from typing import Any
import weakref

from ._grammar import DeferredReference, RawFunction, Terminal, string
from ._utils import strip_multiline_string_indents, make_weak_bound_method, signature_pop
from .models import Model

@@ -16,85 +18,155 @@ def guidance(
    model = Model,
):
    """Decorator used to define guidance grammars"""
    # if we are not yet being used as a decorator, then save the args
    if f is None:
        return functools.partial(
            guidance, stateless=stateless, cache=cache, dedent=dedent, model=model,
        )

    # this strips out indentation in multiline strings that aligns with the current python indentation
    if dedent is True or dedent == "python":
        f = strip_multiline_string_indents(f)

    return GuidanceFunction(f, stateless=stateless, cache=cache, model=model)


class GuidanceFunction:
    def __init__(
        self,
        f,
        *,
        stateless = False,
        cache = False,
        model = Model,
    ):
        self.f = f
        self.stateless = stateless
        self.cache = cache
        self.model = model
        self._impl = _decorator(f, stateless=stateless, cache=cache, model=model)
        self._methods: dict[Any, GuidanceMethod] = {}

        # Update self with the wrapped function's metadata
        functools.update_wrapper(self, self._impl)
        # Pretend to be one level of wrapping lower than we are
        self.__wrapped__ = self._impl.__wrapped__

    def __call__(self, *args, **kwargs):
        return self._impl(*args, **kwargs)

    def __get__(self, instance, owner=None, /):
        """
        Return a GuidanceMethod bound to the instance.
        """
        if instance is None:
            return self
        return GuidanceMethod.from_guidance_function(self, instance)

    def __repr__(self):
        return f"<GuidanceFunction {self.__module__}.{self.__qualname__}{self.__signature__}>"
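
The method support hinges on GuidanceFunction being a descriptor: attribute access on an instance routes through __get__, while access on the class returns the GuidanceFunction itself. A minimal sketch of that mechanism in isolation (generic names, not the PR's classes):

    # Standalone illustration of the descriptor protocol relied on above.
    class BoundThing:
        def __init__(self, func, instance):
            self.func, self.instance = func, instance

        def __call__(self, *args, **kwargs):
            return self.func(self.instance, *args, **kwargs)

    class ThingDescriptor:
        def __init__(self, func):
            self.func = func

        def __get__(self, instance, owner=None):
            if instance is None:
                return self                           # accessed on the class: return the descriptor itself
            return BoundThing(self.func, instance)    # accessed on an instance: return a bound wrapper

    class Widget:
        describe = ThingDescriptor(lambda self: f"widget {id(self)}")

    w = Widget()
    assert isinstance(Widget.describe, ThingDescriptor)  # class access bypasses binding
    print(w.describe())                                  # instance access goes through __get__ and binds `self`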

class GuidanceMethod:
    impl_cache = {}
    def __init__(self, impl, instance):
        # Make object that looks like a method (__self__ and __func__) in order to be able to better support weak referencing via weakref.WeakMethod
        # Note we keep a hard reference to the instance to keep it (and therefore our cached impl) alive as long as we are alive
        self.__self__ = instance
        self.__func__ = impl

        # Update self with the wrapped function's metadata
        functools.update_wrapper(self, impl)
        # Pretend to be one level of wrapping lower than we are
        self.__wrapped__ = impl.__wrapped__

    @classmethod
    def from_guidance_function(cls, guidance_function: GuidanceFunction, instance: Any) -> "GuidanceMethod":
        # We can't directly use a weakref.WeakKeyDictionary because those don't really work when the key objects
        # are allowed to change their hash value.

        # Instead use instance hash in addition to identity to make sure we miss the cache if the instance is meaningfully mutated.
        # This should be safe because an id will only be reused after the original object is garbage collected, at which point we
        # should have removed the cache entry (since we use weakref.finalize to remove the cache entry when the instance is deleted).
        key = (guidance_function.f, hash(instance), id(instance))
        try:
            impl = cls.impl_cache[key]
        except KeyError:
            # Make a weak bound method to prevent the instance from being kept alive by the cache
            weak_method = make_weak_bound_method(guidance_function.f, instance)
            impl = _decorator(weak_method, stateless=guidance_function.stateless, cache=guidance_function.cache, model=guidance_function.model)
            cls.impl_cache[key] = impl
            # Clean up the cache when the instance is deleted
            weakref.finalize(instance, cls.impl_cache.pop, key)
        return cls(impl, instance)

    def __call__(self, *args, **kwargs):
        return self.__func__(*args, **kwargs)

    def __repr__(self):
        return f"<bound GuidanceMethod {self.__qualname__} of {self.__self__!r}>"
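
The from_guidance_function cache keys on (function, hash(instance), id(instance)). The standalone sketch below shows the behavior that combination is meant to give, namely a cache miss when an instance's hash changes and automatic cleanup when the instance is collected (hypothetical helper names, not the PR's code):

    import weakref

    impl_cache = {}

    def get_impl(func, instance):
        # keyed on identity *and* current hash, so mutating the instance invalidates the entry
        key = (func, hash(instance), id(instance))
        try:
            return impl_cache[key]
        except KeyError:
            impl = (func, repr(instance))  # stand-in for the real compiled implementation
            impl_cache[key] = impl
            # drop the entry once the instance is garbage collected
            weakref.finalize(instance, impl_cache.pop, key, None)
            return impl

    class Config:
        def __init__(self, n):
            self.n = n
        def __hash__(self):
            return hash(self.n)

    c = Config(1)
    a = get_impl(str, c)
    assert get_impl(str, c) is a      # same hash and id: cache hit
    c.n = 2                           # meaningful mutation changes the hash
    assert get_impl(str, c) is not a  # cache miss, a fresh entry is built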


_null_grammar = string("")


def _decorator(f, *, stateless, cache, model):
    # we cache the function itself if requested
    # do this before updating the wrapper so we can maintain the __wrapped__ chain
    if cache:
        f = functools.cache(f)

    # Use thread local to store the reference to the grammar node for recursive calls
    # Otherwise, shared state between threads may otherwise trick us into thinking we are in a recursive call
    thread_local = threading.local()

    @functools.wraps(f)
    def wrapped(*args, **kwargs):

        # make a stateless grammar if we can
        if stateless is True or (
            callable(stateless) and stateless(*args, **kwargs)
        ):

            # if we have a (deferred) reference set, then we must be in a recursive definition and so we return the reference
            reference = getattr(thread_local, "_self_call_reference_", None)
            if reference is not None:
                return reference

            # otherwise we call the function to generate the grammar
            else:

                # set a DeferredReference for recursive calls (only if we don't have arguments that might make caching a bad idea)
                no_args = len(args) + len(kwargs) == 0
                if no_args:
                    thread_local._self_call_reference_ = DeferredReference()

                try:
                    # call the function to get the grammar node
                    node = f(_null_grammar, *args, **kwargs)
                except:
                    raise
                else:
                    if not isinstance(node, (Terminal, str)):
                        node.name = f.__name__
                    # set the reference value with our generated node
                    if no_args:
                        thread_local._self_call_reference_.value = node
                finally:
                    if no_args:
                        del thread_local._self_call_reference_

                return node

        # otherwise must be stateful (which means we can't be inside a select() call)
        else:
            return RawFunction(f, args, kwargs)

    # Remove the first argument from the wrapped function since we're going to drop the `lm` argument
    wrapped.__signature__ = signature_pop(inspect.signature(f), 0)

    # attach this as a method of the model class (if given)
    # if model is not None:
    #     setattr(model, f.__name__, f)

    return wrapped
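
The thread-local DeferredReference dance above exists so that a zero-argument stateless function can refer to itself while its grammar is still being built. Roughly the usage pattern it enables; select is a real guidance API, but the grammar itself is an illustrative sketch:

    import guidance
    from guidance import select

    @guidance(stateless=True)
    def balanced_parens(lm):
        # The zero-argument recursive call returns a DeferredReference while the
        # grammar is still under construction; once the full node is built, the
        # reference is patched to point at it.
        return lm + select(["()", "(" + balanced_parens() + ")"])
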
guidance/_utils.py (26 changes: 26 additions & 0 deletions). The resulting code after the change:

@@ -6,6 +6,8 @@
import sys
import textwrap
import types
import weakref
import functools

import numpy as np

@@ -113,8 +115,32 @@ def strip_multiline_string_indents(f):
        closure=f.__closure__,
    )
    new_f.__kwdefaults__ = f.__kwdefaults__
    new_f.__qualname__ = f.__qualname__
    new_f.__annotations__ = f.__annotations__
    new_f.__doc__ = f.__doc__
    new_f.__module__ = f.__module__
    return new_f


def make_weak_bound_method(f, instance):
    instance_ref = weakref.ref(instance)
    instance_repr = repr(instance)
    @functools.wraps(f)  # ish
    def weak_bound_f(*args, **kwargs):
        instance = instance_ref()
        if instance is None:
            raise ReferenceError(f"Lost reference to {instance_repr} and cannot bind {f} to it.")
        method = types.MethodType(f, instance)
        return method(*args, **kwargs)

    # remove the first argument from the wrapped function since it is now bound
    weak_bound_f.__signature__ = signature_pop(inspect.signature(f), 0)
    return weak_bound_f


def signature_pop(signature, index):
    params = list(signature.parameters.values())
    params.pop(index)
    return signature.replace(parameters=params)


class CaptureEvents:
    """Creates a scope where all the events are captured in a queue.
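
To make the weak-binding helper concrete, here is a small stdlib-only sketch of the same idea (generic stand-ins, not the PR's tests): once the instance is garbage collected, calling the weakly bound function raises ReferenceError instead of keeping the object alive through the cache.

    import gc
    import weakref

    def make_weak_bound(func, instance):
        ref = weakref.ref(instance)
        def bound(*args, **kwargs):
            obj = ref()
            if obj is None:
                raise ReferenceError("instance has been garbage collected")
            return func(obj, *args, **kwargs)
        return bound

    class Greeter:
        def __init__(self, name):
            self.name = name
        def greet(self, punctuation):
            return f"hello {self.name}{punctuation}"

    g = Greeter("world")
    hello = make_weak_bound(Greeter.greet, g)
    print(hello("!"))          # "hello world!" while `g` is alive
    del g
    gc.collect()
    try:
        hello("!")
    except ReferenceError:
        print("instance was collected; the bound wrapper did not keep it alive")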