Skip to content

Commit

Permalink
@overload the guidance decorator in order to improve static type-…
Browse files Browse the repository at this point in the history
…checking (#1014)

Improve annotations to the `guidance` decorator via a stub file making use of `typing.overload`.
  • Loading branch information
hudson-ai authored Sep 11, 2024
1 parent 35591d8 commit 003917c
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 16 deletions.
4 changes: 2 additions & 2 deletions guidance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import types

from . import models
from ._guidance import _decorator, guidance
from ._guidance import guidance

from ._grammar import (
RawFunction,
Expand All @@ -21,7 +21,7 @@ class _Guidance(types.ModuleType):
def __call__(
self, f=None, *, stateless=False, cache=None, dedent=True, model=models.Model
):
return _decorator(
return guidance(
f, stateless=stateless, cache=cache, dedent=dedent, model=model
)

Expand Down
18 changes: 13 additions & 5 deletions guidance/_guidance.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import functools
import inspect

from . import models
from ._grammar import RawFunction, Terminal, string, DeferredReference
from ._grammar import DeferredReference, RawFunction, Terminal, string
from ._utils import strip_multiline_string_indents


def guidance(f=None, *, stateless=False, cache=None, dedent=True, model=models.Model):
from .models import Model


def guidance(
f = None,
*,
stateless = False,
cache = False,
dedent = True,
model = Model,
):
"""Decorator used to define guidance grammars"""
return _decorator(f, stateless=stateless, cache=cache, dedent=dedent, model=model)


Expand Down
95 changes: 95 additions & 0 deletions guidance/_guidance.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import sys
from typing import (
Any,
Callable,
Literal,
TypeVar,
Union,
overload,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeAlias, Concatenate
else:
from typing_extensions import ParamSpec, TypeAlias, Concatenate

from ._grammar import GrammarFunction, RawFunction
from .models import Model

P = ParamSpec("P")
M: TypeAlias = Any # sort of Union[Model, GrammarFunction]?
R = TypeVar("R", bound = Union[RawFunction, GrammarFunction])
GuidanceWrappable = Callable[Concatenate[M, P], M]
GuidanceFunction = Callable[P, R]
StatefulGuidanceFunction = GuidanceFunction[P, RawFunction]
StatelessGuidanceFunction = GuidanceFunction[P, GrammarFunction]

@overload
def guidance(
f: GuidanceWrappable[P],
*,
stateless: Literal[False] = False,
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> StatefulGuidanceFunction[P]:
...


@overload
def guidance(
f: None = None,
*,
stateless: Literal[False] = False,
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> Callable[[GuidanceWrappable[P]], StatefulGuidanceFunction[P]]:
...


@overload
def guidance(
f: GuidanceWrappable[P],
*,
stateless: Literal[True],
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> StatelessGuidanceFunction[P]:
...


@overload
def guidance(
f: None = None,
*,
stateless: Literal[True],
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> Callable[[GuidanceWrappable[P]], StatelessGuidanceFunction[P]]:
...


@overload
def guidance(
f: GuidanceWrappable[P],
*,
stateless: Callable[..., bool],
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> GuidanceFunction[P, Union[RawFunction, GrammarFunction]]:
...


@overload
def guidance(
f: None = None,
*,
stateless: Callable[..., bool],
cache: bool = ...,
dedent: bool = ...,
model: type[Model] = ...,
) -> Callable[[GuidanceWrappable[P]], GuidanceFunction[P, Union[RawFunction, GrammarFunction]]]:
...
2 changes: 1 addition & 1 deletion guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _gen_json_array(
# must be present before the next one may be added, meaning we have nested optionals:
# (first optional(,second optional(,third (optional(,...)))))
first, *rest = optional_items
tail = ""
tail: Union[str, GrammarFunction] = ""
for item in reversed(rest):
tail = optional("," + item + tail)
tail = first + tail
Expand Down
6 changes: 3 additions & 3 deletions guidance/library/_substring.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, Dict, Union

from .._guidance import guidance

# from ._prefix_tree import prefix_tree
from .._grammar import string, select, capture, as_regular_grammar
from .._grammar import string, select, capture, as_regular_grammar, Terminal, GrammarFunction
from ._optional import optional


Expand Down Expand Up @@ -95,7 +95,7 @@ def sa_extend(self, c):
@guidance(stateless=True, dedent=False)
def substring(lm, target_string: str, name: Optional[str] = None):
suffix_automaton = SuffixAutomaton(target_string)
node_cache = {}
node_cache: Dict[int, Union[Terminal, GrammarFunction]] = {}
state_stack = [0] # Start with the initial state index (0) on the stack

# Loop as long as there are states on the stack
Expand Down
10 changes: 5 additions & 5 deletions guidance/library/_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def __init__(self, call_grammar=None, tool_call=None, callable=None):
self.tool_call = tool_call


arg = lexeme(r"[^,=)]+")
kwarg = arg + "=" + arg
args = arg + zero_or_more("," + arg)
kwargs = kwarg + zero_or_more("," + kwarg)

def basic_func_grammar(name):
arg = lexeme(r"[^,=)]+")
kwarg = arg + "=" + arg
args = arg + zero_or_more("," + arg)
kwargs = kwarg + zero_or_more("," + kwarg)

obj = name + "("
obj += subgrammar(
name="tool_args",
Expand Down

0 comments on commit 003917c

Please sign in to comment.