
Commit 59b2ed1

Develop (#20)

* feat: Function indexer improvements

  - Added support for any chroma client (http, persistent, etc.)
  - It is now possible to pass an embedding function to use with the collection in chroma (note: this may not be ideal if different embeddings are used in different runs)
  - Function Indexer now supports a default LLM interface used for function wrappers
  - All indexed functions are now loaded at Function Indexer instantiation
  - Function searching now returns the named tuple SearchResult (['name', 'wrapper', 'function', 'distance'])
  - Improved the way a function is called from the wrapper to also accept *args in addition to **kwargs
  - Each wrapped function now gets an identifier (e.g. `tests.test_findex.TestClass.test_class_method`) so the reference to the function can be found at Function Indexer startup. We also add a hash, which is the SHA-1 of the identifier
  - Nested functions are not supported in the wrapper
  - It is not possible to rehydrate partial functions that have been indexed. Instead, such functions should be re-added to the index at app startup (not to worry: the indexer will check the hashes and will not attempt to add an existing function)
  - Updated docs for function indexing

  BREAKING-CHANGE: The `FunctionIndexer` init and other method signatures have changed and are not backward compatible.

  Refs: #17

  ---------

  Signed-off-by: Trayan Azarov <[email protected]>
1 parent ddb675c commit 59b2ed1

7 files changed (+379, -66 lines)
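
The commit message above describes giving each wrapped function a dotted identifier plus a SHA-1 hash of that identifier. A minimal illustrative sketch of that scheme (standalone code that assumes the identifier is the module path plus the qualified name; not the library's exact implementation):

```python
import hashlib


def make_identifier_and_hash(func) -> tuple[str, str]:
    # Dotted identifier such as "tests.test_findex.TestClass.test_class_method"
    identifier = f"{func.__module__}.{func.__qualname__}"
    # The commit describes the hash as the SHA-1 of the identifier
    digest = hashlib.sha1(identifier.encode("utf-8")).hexdigest()
    return identifier, digest
```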

docs/function-indexing/overview.md (new file, +116 lines)

# Function Indexing

The library supports function indexing (with some limitations). This means that you can index your functions and then query them using the `func-ai` library. This is useful if you want to query your functions using natural language, and especially when you have many functions that cannot fit in the LLM context.

The Function Indexer (FI) relies on the chromadb vector store to store function descriptions and then performs semantic search on those descriptions to find the most relevant functions.

Limitations:

- Partials, while supported for indexing and function wrapping via `OpenAIFunctionWrapper`, cannot be rehydrated from the index once it is reloaded (e.g. after an app restart). The suggested workaround is to reindex the partials at app/script startup; this will not re-add them to the index but will only rehydrate them in the index map (see the sketch below).
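
Below is a minimal sketch of that workaround, assuming a hypothetical partial `say_hello` and default indexer settings with credentials already configured. Because the indexer checks function hashes, re-adding an already indexed function only rehydrates it and does not create a duplicate entry:

```python
from functools import partial

from func_ai.function_indexer import FunctionIndexer


def greet(greeting: str, name: str) -> str:
    """Greet someone with a configurable greeting."""
    return f"{greeting}, {name}!"


# Hypothetical partial that was indexed in a previous run
say_hello = partial(greet, "Hello")

indexer = FunctionIndexer()
# Re-adding at startup rehydrates the partial in the in-memory map;
# functions whose hash is already in the index are not added again.
indexer.index_functions([say_hello])
```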
## Usage

```python
import chromadb
from chromadb import Settings
from dotenv import load_dotenv

from func_ai.function_indexer import FunctionIndexer


def function_to_index(a: int, b: int) -> int:
    """
    This is a function that adds two numbers

    :param a: First number
    :param b: Second number
    :return: Sum of a and b
    """
    return a + b


def another_function_to_index() -> str:
    """
    This is a function that returns hello world

    :return: Hello World
    """
    return "Hello World"


def test_function_indexer_init_no_args_find_function_enhanced_summary():
    load_dotenv()
    _indexer = FunctionIndexer(chroma_client=chromadb.PersistentClient(settings=Settings(allow_reset=True)))
    _indexer.reset_function_index()
    _indexer.index_functions([function_to_index, another_function_to_index], enhanced_summary=True)
    _results = _indexer.find_functions("Add two numbers", max_results=10, similarity_threshold=0.2)
    assert len(_results) == 1
    assert _results[0].function(1, 2) == 3


if __name__ == "__main__":
    test_function_indexer_init_no_args_find_function_enhanced_summary()
```

The above code shows how to use the two main functions of the Function Indexer:

- `index_functions`, which indexes a list of functions
- `find_functions`, which finds functions based on a query string

## API Docs

### FunctionIndexer

Init args:

- `chroma_client`: A chromadb client to use for storing the function index. If not provided, a new client will be created using the default settings (e.g. `chromadb.PersistentClient(settings=Settings(allow_reset=True))`).
- `llm_interface`: An LLM interface to use for function wrapping. If not provided, a new LLM interface will be created using the default settings (e.g. `OpenAIInterface()`).
- `embedding_function`: A function that takes a string and returns an embedding. If not provided, the default embedding function will be used (e.g. `embedding_functions.OpenAIEmbeddingFunction()`).
- `collection_name`: The name of the collection to use for storing the function index. Defaults to `function_index`.

> Note: You should always initialize your FunctionIndexer with the same embedding function.
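
For example, here is a sketch of wiring the indexer to a Chroma server over HTTP with an explicitly chosen embedding function (host, port and model name are placeholders):

```python
import chromadb
from chromadb.utils import embedding_functions

from func_ai.function_indexer import FunctionIndexer

# Placeholder connection details; adjust to your Chroma deployment
client = chromadb.HttpClient(host="localhost", port=8000)
ef = embedding_functions.OpenAIEmbeddingFunction(model_name="text-embedding-ada-002")

# Use the same embedding function on every run so stored and query
# embeddings stay comparable (see the note above)
indexer = FunctionIndexer(chroma_client=client,
                          embedding_function=ef,
                          collection_name="function_index")
```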

#### `index_functions`

Args:

- `functions`: A list of functions to index
- `enhanced_summary`: If True, an LLM-generated summary of the function source is indexed instead of the plain function description. Defaults to False.
- `llm_interface`: An LLM interface to use for function wrapping. If not provided, the one passed at Indexer init will be used.
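
Functions whose hash is already present in the collection are skipped and only rehydrated in the in-memory map, so repeated indexing is effectively idempotent. A quick sketch of the expected behaviour (reusing `function_to_index` from the Usage example and assuming an `indexer` instance):

```python
indexer.index_functions([function_to_index])  # added to the collection
indexer.index_functions([function_to_index])  # hash already present: rehydrated only, not re-added
```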

#### `find_functions`

Args:

- `query`: The query string to use for finding functions
- `max_results`: The maximum number of results to return. Defaults to 2.
- `similarity_threshold`: The similarity threshold used for filtering results. Defaults to 1.0.

Returns a list of `SearchResult` named tuples with the following fields:

- `function`: The actual function, which can be called directly
- `name`: The function name
- `wrapper`: The `OpenAIFunctionWrapper` function wrapper
- `distance`: The distance of the function from the query string

> Note: The returned list is sorted by distance in ascending order (i.e. the first result is the closest to the query).
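
A short sketch of consuming the results, assuming an `indexer` instance with the addition function from the Usage example already indexed (query text and threshold are arbitrary):

```python
results = indexer.find_functions("Add two numbers", max_results=5, similarity_threshold=0.5)

for result in results:
    # Each entry exposes the name, the wrapper, the raw callable and the distance
    print(result.name, result.distance)

if results:
    best = results[0]           # the list is sorted by ascending distance
    print(best.function(1, 2))  # call the underlying function directly
```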

#### `functions_summary`

Returns: A dictionary containing function names and their descriptions.

#### `index_wrapper_functions`

This is identical to `index_functions`, but the list of functions is a list of `OpenAIFunctionWrapper` objects (see the sketch below).
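
For instance, a sketch of wrapping a plain function and indexing the wrapper directly (reuses `function_to_index` from the Usage example and assumes an existing `indexer` instance):

```python
from func_ai.utils.llm_tools import OpenAIFunctionWrapper, OpenAIInterface

llm = OpenAIInterface()
wrapper = OpenAIFunctionWrapper.from_python_function(func=function_to_index, llm_interface=llm)

# Identical to index_functions, but takes OpenAIFunctionWrapper objects
indexer.index_wrapper_functions([wrapper], llm_interface=llm)
```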

func_ai/function_indexer.py (+149, -60 lines)

```diff
@@ -1,44 +1,107 @@
+"""
+Function indexer module is responsible for making functions searchable
+"""
+import importlib
 import inspect
 import logging
 import os
+from collections import namedtuple
 
 import chromadb
 import openai
 from chromadb import Settings
+from chromadb.api import EmbeddingFunction
 from chromadb.utils import embedding_functions
-from ulid import ULID
 
-from func_ai.utils.llm_tools import OpenAIFunctionWrapper
+from func_ai.utils.llm_tools import OpenAIFunctionWrapper, OpenAIInterface, LLMInterface
 from func_ai.utils.py_function_parser import func_to_json
 
 logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+SearchResult = namedtuple('SearchResult', ['name', 'wrapper', 'function', 'distance'])
+
 
 class FunctionIndexer(object):
     """
     Index functions
     """
 
-    def __init__(self, db_path: str, collection_name: str = "function_index", **kwargs) -> None:
+    def __init__(self, llm_interface: LLMInterface = OpenAIInterface(),
+                 chroma_client: chromadb.Client = chromadb.PersistentClient(settings=Settings(allow_reset=True)),
+                 embedding_function: EmbeddingFunction = None,
+                 collection_name: str = "function_index", **kwargs) -> None:
         """
         Initialize function indexer
         :param db_path: The path where to store the database
         :param collection_name: The name of the collection
         :param kwargs: Additional arguments
         """
-        self._client = chromadb.PersistentClient(path=db_path, settings=Settings(
-            anonymized_telemetry=False,
-            allow_reset=True,
-        ))
+        # self._client = chromadb.PersistentClient(path=db_path, settings=Settings(
+        #     anonymized_telemetry=False,
+        #     allow_reset=True,
+        # ))
+        self._client = chroma_client
         openai.api_key = kwargs.get("openai_api_key", os.getenv("OPENAI_API_KEY"))
+        if embedding_function is None:
+            self._embedding_function = embedding_functions.OpenAIEmbeddingFunction()
+        else:
+            self._embedding_function = embedding_function
         self.collection_name = collection_name
+        self._init_collection()
+
+        self._llm_interface = llm_interface
         self._fns_map = {}
         self._fns_index_map = {}
         self._open_ai_function_map = []
-        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-            model_name="text-embedding-ada-002"
-        )
+        self._functions = {}
+        _get_results = self._collection.get()
+        if _get_results is not None:
+            for idx, m in enumerate(_get_results['metadatas']):
+                if "is_partial" in m and bool(m["is_partial"]):
+                    logger.warning(
+                        f"Found partial function {m['name']}. This function will not be rehydrated into the index.")
+                    continue
+                self._functions[m["hash"]] = OpenAIFunctionWrapper.from_python_function(
+                    func=FunctionIndexer.function_from_ref(m["identifier"]), llm_interface=self._llm_interface)
+
+    def _init_collection(self) -> None:
+        self._collection = self._client.get_or_create_collection(name=self.collection_name,
+                                                                  metadata={"hnsw:space": "cosine"},
+                                                                  embedding_function=self._embedding_function)
+
+    @staticmethod
+    def function_from_ref(ref_identifier: str) -> callable:
+        """
+        Get function from reference
+        :param ref_identifier: The reference identifier
+        :return: The function
+        """
+        parts = ref_identifier.split('.')
+        _fn = parts[-1]
+        _mod = ""
+        _last_mod = ""
+        _module = None
+        for pt in parts[:-1]:
+            try:
+                _last_mod = str(_mod)
+                _mod += pt
+                _module = importlib.import_module(_mod)
+                _mod += "."
+                # module = importlib.import_module('.'.join(parts[:-1]))
+                # function = getattr(module, _fn)
+            except ModuleNotFoundError:
+                print("Last module: ", _last_mod)
+                _module = importlib.import_module(_last_mod[:-1] if _last_mod.endswith(".") else _last_mod)
+                _module = getattr(_module, pt)
+                # print(f"Module: {getattr(module, pt)}")
+                # module = importlib.import_module('.'.join(parts[:-1]))
+        if _module is None:
+            raise ModuleNotFoundError(f"Could not find module {_mod}")
+        _fn = _module
+        part = parts[-1]
+        _fn = getattr(_fn, part)
+        return _fn
 
     def reset_function_index(self) -> None:
         """
@@ -48,68 +111,91 @@ def reset_function_index(self) -> None:
         """
 
         self._client.reset()
+        self._init_collection()
 
-    def index_functions(self, functions: list[callable]) -> None:
+    def index_functions(self, functions: list[callable or OpenAIFunctionWrapper],
+                        llm_interface: LLMInterface = None,
+                        enhanced_summary: bool = False) -> None:
         """
         Index one or more functions
         Note: Function uniqueness is not checked in this version
 
-        :param functions:
+        :param llm_interface: The LLM interface
+        :param functions: The functions to index
+        :param enhanced_summary: Whether to use enhanced summary
         :return:
         """
-
-        _ai_fun_map, _fns_map, _fns_index_map = FunctionIndexer.get_functions(functions)
-        collection = self._client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"},
-                                                           embedding_function=self.openai_ef)
-        self._fns_map.update(_fns_map)
-        self._open_ai_function_map.extend(_ai_fun_map)
-        self._fns_index_map.update(_fns_index_map)
-        collection.add(documents=[f['description'] for f in _fns_index_map.values()],
-                       metadatas=[{"name": f,
-                                   "file": str(inspect.getfile(_fns_map[f])),
-                                   "module": inspect.getmodule(_fns_map[f]).__name__} for f, v in
-                                  _fns_index_map.items()],
-                       ids=[str(ULID()) for _ in _fns_index_map.values()])
-
-    def index_wrapper_functions(self, functions: list[OpenAIFunctionWrapper]):
+        _fn_llm_interface = llm_interface if llm_interface is not None else self._llm_interface
+        _wrapped_functions = [
+            OpenAIFunctionWrapper.from_python_function(func=f, llm_interface=_fn_llm_interface) for f
+            in functions if not isinstance(f, OpenAIFunctionWrapper)]
+        _wrapped_functions.extend([f for f in functions if isinstance(f, OpenAIFunctionWrapper)])
+        _fn_hashes = [f.hash for f in _wrapped_functions]
+        _existing_fn_results = self._collection.get(ids=_fn_hashes)
+        print(_existing_fn_results)
+        # filter wrapped functions that are already in the index
+        _original_wrapped_functions = _wrapped_functions.copy()
+        _wrapped_functions = [f for f in _wrapped_functions if f.hash not in _existing_fn_results["ids"]]
+        if len(_wrapped_functions) == 0:
+            logger.info("No new functions to index")
+            self._functions.update(
+                {f.hash: f for f in _original_wrapped_functions})  # we only rehydrate that are already in the index
+            return
+        _docs = []
+        _metadatas = []
+        _ids = []
+        _function_summarizer = OpenAIInterface(max_tokens=200)
+        for f in _wrapped_functions:
+            if enhanced_summary:
+                _function_summarizer.add_conversation_message(
+                    {"role": "system",
+                     "content": "You are an expert summarizer. Your purpose is to provide a good summary of the function so that the user can add the summary in an embedding database which will them be searched."})
+                _fsummary = _function_summarizer.send(f"Summarize the function below.\n\n{inspect.getsource(f.func)}")
+                _docs.append(f"{_fsummary['content']}")
+                _function_summarizer.conversation_store.clear()
+            else:
+                _docs.append(f"{f.description}")
+            _metadatas.append(
+                {"name": f.name, "identifier": f.identifier, "hash": f.hash, "is_partial": str(f.is_partial),
+                 **f.metadata_dict})
+            _ids.append(f.hash)
+
+        self._collection.add(documents=_docs,
+                             metadatas=_metadatas,
+                             ids=_ids)
+        self._functions.update({f.hash: f for f in _wrapped_functions})
+
+    def index_wrapper_functions(self, functions: list[OpenAIFunctionWrapper],
+                                llm_interface: LLMInterface = None,
+                                enhanced_summary: bool = False) -> None:
         """
         Index one or more functions
         Note: Function uniqueness is not checked in this version
-        :param functions:
-        :return:
+        :param functions: The functions to index
+        :param llm_interface: The LLM interface
+        :param enhanced_summary: Whether to use enhanced summary
+        :return: None
         """
-        collection = self._client.get_or_create_collection(name=self.collection_name,
-                                                           metadata={"hnsw:space": "cosine"},
-                                                           embedding_function=self.openai_ef)
-        # print(f"Docs: {collection.get()}")
-        collection.add(documents=[f.description for f in functions],
-                       metadatas=[{"name": f.name, **f.metadata_dict} for f in
-                                  functions],
-                       ids=[str(ULID()) for _ in functions])
+        self.index_functions(functions=functions, llm_interface=llm_interface, enhanced_summary=enhanced_summary)
 
-    def rehydrate_function_map(self, functions: list[callable]) -> None:
+    def get_ai_fn_abbr_map(self) -> dict[str, str]:
         """
-        Rehydrate function map
+        Get AI function abbreviated map
 
-        :param functions:
-        :return:
+        :return: Map of function name (key) and description (value)
         """
 
-        _ai_fun_map, _fns_map, _fns_index_map = FunctionIndexer.get_functions(functions)
-        self._fns_map.update(_fns_map)
-        self._open_ai_function_map.extend(_ai_fun_map)
-        self._fns_index_map.update(_fns_index_map)
+        return {f['name']: f['description'] for f in self._open_ai_function_map}
 
-    def get_ai_fn_abbr_map(self) -> dict[str, str]:
+    def functions_summary(self) -> dict[str, str]:
         """
-        Get AI function abbreviated map
+        Get functions summary
 
         :return: Map of function name (key) and description (value)
         """
+        return {f.name: f.description for f in self._functions.values()}
 
-        return {f['name']: f['description'] for f in self._open_ai_function_map}
-
-    def find_functions(self, query: str, max_results: int = 2, similarity_threshold: float = 1.0) -> callable:
+    def find_functions(self, query: str, max_results: int = 2, similarity_threshold: float = 1.0) -> list[SearchResult]:
         """
         Find functions by description
 
@@ -119,16 +205,19 @@ def find_functions(self, query: str, max_results: int = 2, similarity_threshold:
         :return:
         """
         _response = []
-        collection = self._client.get_or_create_collection(name=self.collection_name,
-                                                           metadata={"hnsw:space": "cosine"},
-                                                           embedding_function=self.openai_ef)
-        # print(collection.get())
-        res = collection.query(query_texts=[query], n_results=max_results)
-        print(f"Got results from sematic search: {res}")
-        for r in range(len(res['documents'][0])):
-            print(f"Distance: {res['distances'][0][r]} vs threshold: {similarity_threshold}")
-            if res['distances'][0][r] <= similarity_threshold:
-                _response.append(res['metadatas'][0][r]['name'])
+        # print(self._functions.keys())
+        res = self._collection.query(query_texts=[query], n_results=max_results)
+        # print(f"Got results from sematic search: {res}")
+        for idx, _ in enumerate(res['documents'][0]):
+            print(f"Distance: {res['distances'][0][idx]} vs threshold: {similarity_threshold}")
+            if res['distances'][0][idx] <= similarity_threshold:
+                _search_res = SearchResult(name=res['metadatas'][0][idx]['name'],
+                                           function=self._functions[res['metadatas'][0][idx]['hash']].func,
+                                           wrapper=self._functions[res['metadatas'][0][idx]['hash']],
+                                           distance=res['distances'][0][idx])
+                _response.append(_search_res)
+
+        _response.sort(key=lambda x: x.distance)
         return _response
 
     @staticmethod
```
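
The `function_from_ref` static method added in the diff above resolves a stored dotted identifier back to a callable when the indexer starts up. A quick illustrative call using a standard-library function so the sketch is self-contained:

```python
from func_ai.function_indexer import FunctionIndexer

# Resolve a dotted identifier back to the callable it names
join = FunctionIndexer.function_from_ref("os.path.join")
print(join("a", "b"))  # "a/b" on POSIX, "a\\b" on Windows
```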

0 commit comments
