+ """
+ Function indexer module is responsible for making functions searchable
+ """
+ import importlib
import inspect
import logging
import os
+ from collections import namedtuple

import chromadb
import openai
from chromadb import Settings
+ from chromadb.api import EmbeddingFunction
from chromadb.utils import embedding_functions
- from ulid import ULID

- from func_ai.utils.llm_tools import OpenAIFunctionWrapper
+ from func_ai.utils.llm_tools import OpenAIFunctionWrapper, OpenAIInterface, LLMInterface
from func_ai.utils.py_function_parser import func_to_json

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
logger = logging.getLogger(__name__)

+ SearchResult = namedtuple('SearchResult', ['name', 'wrapper', 'function', 'distance'])
+


class FunctionIndexer(object):
    """
    Index functions
    """

-     def __init__(self, db_path: str, collection_name: str = "function_index", **kwargs) -> None:
+     def __init__(self, llm_interface: LLMInterface = OpenAIInterface(),
+                  chroma_client: chromadb.Client = chromadb.PersistentClient(settings=Settings(allow_reset=True)),
+                  embedding_function: EmbeddingFunction = None,
+                  collection_name: str = "function_index", **kwargs) -> None:
        """
        Initialize function indexer
        :param db_path: The path where to store the database
        :param collection_name: The name of the collection
        :param kwargs: Additional arguments
        """
-         self._client = chromadb.PersistentClient(path=db_path, settings=Settings(
-             anonymized_telemetry=False,
-             allow_reset=True,
-         ))
+         # self._client = chromadb.PersistentClient(path=db_path, settings=Settings(
+         #     anonymized_telemetry=False,
+         #     allow_reset=True,
+         # ))
+         self._client = chroma_client
        openai.api_key = kwargs.get("openai_api_key", os.getenv("OPENAI_API_KEY"))
+         if embedding_function is None:
+             self._embedding_function = embedding_functions.OpenAIEmbeddingFunction()
+         else:
+             self._embedding_function = embedding_function
        self.collection_name = collection_name
+         self._init_collection()
+
+         self._llm_interface = llm_interface
        self._fns_map = {}
        self._fns_index_map = {}
        self._open_ai_function_map = []
-         self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-             model_name="text-embedding-ada-002"
-         )
+         self._functions = {}
+         _get_results = self._collection.get()
+         if _get_results is not None:
+             for idx, m in enumerate(_get_results['metadatas']):
+                 if "is_partial" in m and bool(m["is_partial"]):
+                     logger.warning(
+                         f"Found partial function {m['name']}. This function will not be rehydrated into the index.")
+                     continue
+                 self._functions[m["hash"]] = OpenAIFunctionWrapper.from_python_function(
+                     func=FunctionIndexer.function_from_ref(m["identifier"]), llm_interface=self._llm_interface)
+
+     def _init_collection(self) -> None:
+         self._collection = self._client.get_or_create_collection(name=self.collection_name,
+                                                                  metadata={"hnsw:space": "cosine"},
+                                                                  embedding_function=self._embedding_function)
+
+     @staticmethod
+     def function_from_ref(ref_identifier: str) -> callable:
+         """
+         Get function from reference
+         :param ref_identifier: The reference identifier
+         :return: The function
+         """
+         parts = ref_identifier.split('.')
+         _fn = parts[-1]
+         _mod = ""
+         _last_mod = ""
+         _module = None
+         for pt in parts[:-1]:
+             try:
+                 _last_mod = str(_mod)
+                 _mod += pt
+                 _module = importlib.import_module(_mod)
+                 _mod += "."
+                 # module = importlib.import_module('.'.join(parts[:-1]))
+                 # function = getattr(module, _fn)
+             except ModuleNotFoundError:
+                 print("Last module: ", _last_mod)
+                 _module = importlib.import_module(_last_mod[:-1] if _last_mod.endswith(".") else _last_mod)
+                 _module = getattr(_module, pt)
+                 # print(f"Module: {getattr(module, pt)}")
+                 # module = importlib.import_module('.'.join(parts[:-1]))
+         if _module is None:
+             raise ModuleNotFoundError(f"Could not find module {_mod}")
+         _fn = _module
+         part = parts[-1]
+         _fn = getattr(_fn, part)
+         return _fn

    def reset_function_index(self) -> None:
        """
@@ -48,68 +111,91 @@ def reset_function_index(self) -> None:
        """

        self._client.reset()
+         self._init_collection()

-     def index_functions(self, functions: list[callable]) -> None:
+     def index_functions(self, functions: list[callable or OpenAIFunctionWrapper],
+                         llm_interface: LLMInterface = None,
+                         enhanced_summary: bool = False) -> None:
        """
        Index one or more functions
        Note: Function uniqueness is not checked in this version

-         :param functions:
+         :param llm_interface: The LLM interface
+         :param functions: The functions to index
+         :param enhanced_summary: Whether to use enhanced summary
        :return:
        """
-
-         _ai_fun_map, _fns_map, _fns_index_map = FunctionIndexer.get_functions(functions)
-         collection = self._client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"},
-                                                             embedding_function=self.openai_ef)
-         self._fns_map.update(_fns_map)
-         self._open_ai_function_map.extend(_ai_fun_map)
-         self._fns_index_map.update(_fns_index_map)
-         collection.add(documents=[f['description'] for f in _fns_index_map.values()],
-                        metadatas=[{"name": f,
-                                    "file": str(inspect.getfile(_fns_map[f])),
-                                    "module": inspect.getmodule(_fns_map[f]).__name__} for f, v in
-                                   _fns_index_map.items()],
-                        ids=[str(ULID()) for _ in _fns_index_map.values()])
-
-     def index_wrapper_functions(self, functions: list[OpenAIFunctionWrapper]):
+         _fn_llm_interface = llm_interface if llm_interface is not None else self._llm_interface
+         _wrapped_functions = [
+             OpenAIFunctionWrapper.from_python_function(func=f, llm_interface=_fn_llm_interface) for f
+             in functions if not isinstance(f, OpenAIFunctionWrapper)]
+         _wrapped_functions.extend([f for f in functions if isinstance(f, OpenAIFunctionWrapper)])
+         _fn_hashes = [f.hash for f in _wrapped_functions]
+         _existing_fn_results = self._collection.get(ids=_fn_hashes)
+         print(_existing_fn_results)
+         # filter wrapped functions that are already in the index
+         _original_wrapped_functions = _wrapped_functions.copy()
+         _wrapped_functions = [f for f in _wrapped_functions if f.hash not in _existing_fn_results["ids"]]
+         if len(_wrapped_functions) == 0:
+             logger.info("No new functions to index")
+             self._functions.update(
+                 {f.hash: f for f in _original_wrapped_functions})  # only rehydrate functions that are already in the index
+             return
+         _docs = []
+         _metadatas = []
+         _ids = []
+         _function_summarizer = OpenAIInterface(max_tokens=200)
+         for f in _wrapped_functions:
+             if enhanced_summary:
+                 _function_summarizer.add_conversation_message(
+                     {"role": "system",
+                      "content": "You are an expert summarizer. Your purpose is to provide a good summary of the function so that the user can add the summary to an embedding database which will then be searched."})
+                 _fsummary = _function_summarizer.send(f"Summarize the function below.\n\n{inspect.getsource(f.func)}")
+                 _docs.append(f"{_fsummary['content']}")
+                 _function_summarizer.conversation_store.clear()
+             else:
+                 _docs.append(f"{f.description}")
+             _metadatas.append(
+                 {"name": f.name, "identifier": f.identifier, "hash": f.hash, "is_partial": str(f.is_partial),
+                  **f.metadata_dict})
+             _ids.append(f.hash)
+
+         self._collection.add(documents=_docs,
+                              metadatas=_metadatas,
+                              ids=_ids)
+         self._functions.update({f.hash: f for f in _wrapped_functions})
+
+     def index_wrapper_functions(self, functions: list[OpenAIFunctionWrapper],
+                                 llm_interface: LLMInterface = None,
+                                 enhanced_summary: bool = False) -> None:
        """
        Index one or more functions
        Note: Function uniqueness is not checked in this version
-         :param functions:
-         :return:
+         :param functions: The functions to index
+         :param llm_interface: The LLM interface
+         :param enhanced_summary: Whether to use enhanced summary
+         :return: None
        """
-         collection = self._client.get_or_create_collection(name=self.collection_name,
-                                                             metadata={"hnsw:space": "cosine"},
-                                                             embedding_function=self.openai_ef)
-         # print(f"Docs: {collection.get()}")
-         collection.add(documents=[f.description for f in functions],
-                        metadatas=[{"name": f.name, **f.metadata_dict} for f in
-                                   functions],
-                        ids=[str(ULID()) for _ in functions])
+         self.index_functions(functions=functions, llm_interface=llm_interface, enhanced_summary=enhanced_summary)

-     def rehydrate_function_map(self, functions: list[callable]) -> None:
+     def get_ai_fn_abbr_map(self) -> dict[str, str]:
        """
-         Rehydrate function map
+         Get AI function abbreviated map

-         :param functions:
-         :return:
+         :return: Map of function name (key) and description (value)
        """

-         _ai_fun_map, _fns_map, _fns_index_map = FunctionIndexer.get_functions(functions)
-         self._fns_map.update(_fns_map)
-         self._open_ai_function_map.extend(_ai_fun_map)
-         self._fns_index_map.update(_fns_index_map)
+         return {f['name']: f['description'] for f in self._open_ai_function_map}

-     def get_ai_fn_abbr_map(self) -> dict[str, str]:
+     def functions_summary(self) -> dict[str, str]:
        """
-         Get AI function abbreviated map
+         Get functions summary

        :return: Map of function name (key) and description (value)
        """
+         return {f.name: f.description for f in self._functions.values()}

-         return {f['name']: f['description'] for f in self._open_ai_function_map}
-
-     def find_functions(self, query: str, max_results: int = 2, similarity_threshold: float = 1.0) -> callable:
+     def find_functions(self, query: str, max_results: int = 2, similarity_threshold: float = 1.0) -> list[SearchResult]:
        """
        Find functions by description

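A short sketch of the indexing flow after this change, assuming OPENAI_API_KEY is set and that OpenAIFunctionWrapper derives the description from the docstring: functions are wrapped, deduplicated by hash against the collection, and with enhanced_summary=True summarized by a 200-token OpenAIInterface before being embedded.

def greet(name: str) -> str:
    """Return a friendly greeting for the given name."""
    return f"Hello, {name}!"

indexer = FunctionIndexer()
indexer.index_functions([greet])    # wrapped, hashed and embedded into the collection
indexer.index_functions([greet])    # no-op: the hash is already in the index
print(indexer.functions_summary())  # e.g. {'greet': 'Return a friendly greeting for the given name.'}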
@@ -119,16 +205,19 @@ def find_functions(self, query: str, max_results: int = 2, similarity_threshold:
        :return:
        """
        _response = []
-         collection = self._client.get_or_create_collection(name=self.collection_name,
-                                                             metadata={"hnsw:space": "cosine"},
-                                                             embedding_function=self.openai_ef)
-         # print(collection.get())
-         res = collection.query(query_texts=[query], n_results=max_results)
-         print(f"Got results from sematic search: {res}")
-         for r in range(len(res['documents'][0])):
-             print(f"Distance: {res['distances'][0][r]} vs threshold: {similarity_threshold}")
-             if res['distances'][0][r] <= similarity_threshold:
-                 _response.append(res['metadatas'][0][r]['name'])
+         # print(self._functions.keys())
+         res = self._collection.query(query_texts=[query], n_results=max_results)
+ # print(f"Got results from sematic search: {res}")
+         for idx, _ in enumerate(res['documents'][0]):
+             print(f"Distance: {res['distances'][0][idx]} vs threshold: {similarity_threshold}")
+             if res['distances'][0][idx] <= similarity_threshold:
+                 _search_res = SearchResult(name=res['metadatas'][0][idx]['name'],
+                                            function=self._functions[res['metadatas'][0][idx]['hash']].func,
+                                            wrapper=self._functions[res['metadatas'][0][idx]['hash']],
+                                            distance=res['distances'][0][idx])
+                 _response.append(_search_res)
+
+         _response.sort(key=lambda x: x.distance)
        return _response

    @staticmethod
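Continuing the greet sketch above, find_functions now returns SearchResult namedtuples sorted by ascending distance instead of bare metadata names; the 0.8 threshold is illustrative:

results = indexer.find_functions("say hello to someone", max_results=3, similarity_threshold=0.8)
for r in results:
    print(r.name, r.distance)     # indexed name and cosine distance
    print(r.wrapper.description)  # the OpenAIFunctionWrapper stored in the index
    print(r.function("World"))    # the underlying callable -> "Hello, World!"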