Skip to content

Commit

Permalink
fix: correct embed and using arbitrary models on function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
henrycunh committed Oct 8, 2023
1 parent 3ca4b11 commit 7429199
Show file tree
Hide file tree
Showing 8 changed files with 607 additions and 435 deletions.
1 change: 0 additions & 1 deletion cursive/compat/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
try:
from pydantic.v1 import BaseModel, Field, validate_arguments

except ImportError:
from pydantic import BaseModel, Field, validate_arguments

Expand Down
2 changes: 1 addition & 1 deletion cursive/cursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def embed(self, content: str):
"embedding:success",
CursiveHookPayload(
data=result,
time=time() - start,
duration=time() - start,
),
)
self._hooks.call_hook(
Expand Down
57 changes: 56 additions & 1 deletion cursive/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ class CursiveCustomFunction(BaseModel):
function_schema: dict[str, Any]
pause: bool = False

class Config:
arbitrary_types_allowed = True


class CursiveFunction(CursiveCustomFunction):
def __setup__(self, function: Callable):
definition = function
description = dedent(function.__doc__ or "").strip()
parameters = validate_arguments(function).model.schema()


# Delete ['v__duplicate_kwargs', 'args', 'kwargs'] from parameters
for k in ["v__duplicate_kwargs", "args", "kwargs"]:
if k in parameters["properties"]:
Expand All @@ -33,10 +37,15 @@ def __setup__(self, function: Callable):
if parameters:
schema = parameters

properties = schema.get("properties") or {}
definitions = schema.get("definitions") or {}
resolved_properties = remove_key_deep(resolve_ref(properties, definitions), "title")


function_schema = {
"parameters": {
"type": schema.get("type"),
"properties": schema.get("properties") or {},
"properties": resolved_properties,
"required": schema.get("required") or [],
},
"description": description,
Expand Down Expand Up @@ -66,3 +75,49 @@ def decorator(function: Callable = None):
return CursiveFunction(function, pause=pause)

return decorator

def resolve_ref(data, definitions):
"""
Recursively checks for a $ref key in a dictionary and replaces it with an entry in the definitions
dictionary, changing the key `$ref` to `type`.
Args:
data (dict): The data dictionary to check for $ref keys.
definitions (dict): The definitions dictionary to replace $ref keys with.
Returns:
dict: The data dictionary with $ref keys replaced.
"""
if isinstance(data, dict):
if "$ref" in data:
ref = data["$ref"].split('/')[-1]
if ref in definitions:
definition = definitions[ref]
data = definition
else:
for key, value in data.items():
data[key] = resolve_ref(value, definitions)
elif isinstance(data, list):
for index, value in enumerate(data):
data[index] = resolve_ref(value, definitions)
return data

def remove_key_deep(data, key):
"""
Recursively removes a key from a dictionary.
Args:
data (dict): The data dictionary to remove the key from.
key (str): The key to remove from the dictionary.
Returns:
dict: The data dictionary with the key removed.
"""
if isinstance(data, dict):
data.pop(key, None)
for k, v in data.items():
data[k] = remove_key_deep(v, key)
elif isinstance(data, list):
for index, value in enumerate(data):
data[index] = remove_key_deep(value, key)
return data
19 changes: 19 additions & 0 deletions cursive/tests/test_function_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from cursive.compat.pydantic import BaseModel
from cursive.function import cursive_function

def test_function_schema_allows_arbitrary_types():

class Character(BaseModel):
name: str
age: int

@cursive_function()
def gen_arbitrary_type(character: Character):
"""
A test function.
character: A character.
"""
return f"{character.name} is {character.age} years old."

assert 'description' in gen_arbitrary_type.function_schema
13 changes: 13 additions & 0 deletions examples/compare-embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from cursive import Cursive
import numpy as np

cursive = Cursive()

x1 = cursive.embed("""Pizza""")

x2 = cursive.embed("""Cat""")

def cosine_similarity(x, y):
return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

print(cosine_similarity(x1, x2))
28 changes: 28 additions & 0 deletions examples/generate-list-of-objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List
from cursive.compat.pydantic import BaseModel
from cursive.function import cursive_function
from cursive import Cursive

class Input(BaseModel):
input: str
idx: int

@cursive_function(pause=True)
def gen_character_list(inputs: List[Input]):
"""
Given a prompt (which is directives for a LLM), generate possible inputs that could be fed to it.
Generate 10 inputs.
inputs: A list of inputs.
"""
return inputs

c = Cursive()

res = c.ask(
prompt="Generate a headline for a SaaS company.",
model="gpt-4",
function_call=gen_character_list
)

print(res.function_result)
Loading

0 comments on commit 7429199

Please sign in to comment.