Skip to content

Commit 9020992

Browse files
authored
OpenAPI API QA support (#9)
* chore: Added project urls and fixed python vulnerability scan paths * chore: Added some fixes and documentation to py_function_parser. (#2) * chore: Added some fixes and documentation to py_function_parser. Added py parser tests. * fix: Fixing vuln scan action * feat: Added LLM tools, Added class inheritance for classes that need LLM handling. * feat: Added support for enums. Updated tests. Added links to inspiration projects * feat: First 'naive' iteration of openai function calling with openAPI specs * chore: version bump * feat: OpenAPI Parsing and calling is now implemented and working. * feat: Jinja2 template rendering with prompt and OpenAI Functions * feat: Added custom filter support. * feat: Refactored OpenAPI calling. Added an abstraction for OpenAI functions OpenAIFunctionWrapper. It is still in early dev. * feat: Added json serialization of wrappers. * feat: Fixing python vulnerability scan workflow. * chore: Version bump for json serialization feature of OpenAPI wrapper * feat: OpenAPI spec chatbot (#8) Refs: #7 * feat: Added OpenAPI QA bot support and example gradio chatbot Refs: #7 --------- Signed-off-by: Trayan Azarov <[email protected]>
1 parent 5f06f07 commit 9020992

12 files changed

+1235
-83
lines changed

README.md

+55
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,61 @@ Name: John
119119
Age: 20
120120
"""
121121
```
122+
### OpenAPI Spec Chat Bot
123+
124+
This example starts a `gradio` server that allows you to interact with the OpenAPI spec.
125+
126+
```python
127+
import gradio as gr
128+
from dotenv import load_dotenv
129+
130+
from func_ai.utils.llm_tools import OpenAIInterface
131+
from func_ai.utils.openapi_function_parser import OpenAPISpecOpenAIWrapper
132+
133+
_chat_message = []
134+
135+
_spec = None
136+
137+
138+
def add_text(history, text):
139+
global _chat_message
140+
history = history + [(text, None)]
141+
_chat_message.append(_spec.api_qa(text, max_tokens=500))
142+
return history, ""
143+
144+
145+
def add_file(history, file):
146+
history = history + [((file.name,), None)]
147+
return history
148+
149+
150+
def bot(history):
151+
global _chat_message
152+
# print(temp_callback_handler.get_output())
153+
# response = temp_callback_handler.get_output()['output']
154+
history[-1][1] = _chat_message[-1]
155+
return history
156+
157+
158+
with gr.Blocks() as demo:
159+
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=1500)
160+
161+
with gr.Row():
162+
with gr.Column(scale=1):
163+
txt = gr.Textbox(
164+
show_label=False,
165+
placeholder="Enter text and press enter",
166+
).style(container=False)
167+
txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
168+
bot, chatbot, chatbot
169+
)
170+
171+
if __name__ == "__main__":
172+
load_dotenv()
173+
_spec = OpenAPISpecOpenAIWrapper.from_url('http://petstore.swagger.io/v2/swagger.json',
174+
llm_interface=OpenAIInterface(), index=True)
175+
demo.launch()
176+
```
122177

123178
## Inspiration
124179

func_ai/function_indexer.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,40 @@
11
import inspect
2-
import json
2+
import logging
33

44
import chromadb
55
from chromadb import Settings
66
from chromadb.utils import embedding_functions
77
from ulid import ULID
88

9+
from func_ai.utils.llm_tools import OpenAIFunctionWrapper
910
from func_ai.utils.py_function_parser import func_to_json
1011

11-
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
12-
model_name="text-embedding-ada-002"
13-
)
12+
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
13+
logger = logging.getLogger(__name__)
1414

1515

1616
class FunctionIndexer(object):
1717
"""
1818
Index functions
1919
"""
2020

21-
def __init__(self, db_path: str) -> None:
21+
def __init__(self, db_path: str, collection_name: str = "function_index") -> None:
2222
"""
2323
Initialize function indexer
24-
:param db_path:
24+
:param db_path: The path where to store the database
25+
:param collection_name: The name of the collection
2526
"""
2627
self._client = chromadb.Client(Settings(
2728
chroma_db_impl="duckdb+parquet",
2829
persist_directory=db_path # Optional, defaults to .chromadb/ in the current directory
2930
))
31+
self.collection_name = collection_name
3032
self._fns_map = {}
3133
self._fns_index_map = {}
3234
self._open_ai_function_map = []
35+
self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
36+
model_name="text-embedding-ada-002"
37+
)
3338

3439
def reset_function_index(self) -> None:
3540
"""
@@ -50,8 +55,8 @@ def index_functions(self, functions: list[callable]) -> None:
5055
"""
5156

5257
_ai_fun_map, _fns_map, _fns_index_map = FunctionIndexer.get_functions(functions)
53-
collection = self._client.get_or_create_collection(name="function_index", metadata={"hnsw:space": "cosine"},
54-
embedding_function=openai_ef)
58+
collection = self._client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"},
59+
embedding_function=self.openai_ef)
5560
self._fns_map.update(_fns_map)
5661
self._open_ai_function_map.extend(_ai_fun_map)
5762
self._fns_index_map.update(_fns_index_map)
@@ -62,6 +67,22 @@ def index_functions(self, functions: list[callable]) -> None:
6267
_fns_index_map.items()],
6368
ids=[str(ULID()) for _ in _fns_index_map.values()])
6469

70+
def index_wrapper_functions(self, functions: list[OpenAIFunctionWrapper]):
71+
"""
72+
Index one or more functions
73+
Note: Function uniqueness is not checked in this version
74+
:param functions:
75+
:return:
76+
"""
77+
collection = self._client.get_or_create_collection(name=self.collection_name,
78+
metadata={"hnsw:space": "cosine"},
79+
embedding_function=self.openai_ef)
80+
# print(f"Docs: {collection.get()}")
81+
collection.add(documents=[f.description for f in functions],
82+
metadatas=[{"name": f.name, **f.metadata_dict} for f in
83+
functions],
84+
ids=[str(ULID()) for _ in functions])
85+
6586
def rehydrate_function_map(self, functions: list[callable]) -> None:
6687
"""
6788
Rehydrate function map
@@ -92,13 +113,14 @@ def find_functions(self, query: str, max_results: int = 2) -> callable:
92113
:return:
93114
"""
94115
_response = []
95-
collection = self._client.get_or_create_collection(name="function_index", metadata={"hnsw:space": "cosine"},
96-
embedding_function=openai_ef)
97-
print(collection.get())
116+
collection = self._client.get_or_create_collection(name=self.collection_name,
117+
metadata={"hnsw:space": "cosine"},
118+
embedding_function=self.openai_ef)
119+
# print(collection.get())
98120
res = collection.query(query_texts=[query], n_results=max_results)
99121
print(f"Got response for sematic search: {res}")
100122
for r in res['metadatas'][0]:
101-
_response.append(self._fns_map[r['name']])
123+
_response.append(r['name'])
102124
return _response
103125

104126
@staticmethod

func_ai/ui_demos/__init__.py

Whitespace-only changes.

func_ai/ui_demos/api_qa.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import gradio as gr
2+
from dotenv import load_dotenv
3+
4+
from func_ai.utils.llm_tools import OpenAIInterface
5+
from func_ai.utils.openapi_function_parser import OpenAPISpecOpenAIWrapper
6+
7+
_chat_message = []
8+
9+
_spec = None
10+
11+
12+
def add_text(history, text):
13+
global _chat_message
14+
history = history + [(text, None)]
15+
_chat_message.append(_spec.api_qa(text, max_tokens=500))
16+
return history, ""
17+
18+
19+
def add_file(history, file):
20+
history = history + [((file.name,), None)]
21+
return history
22+
23+
24+
def bot(history):
25+
global _chat_message
26+
# print(temp_callback_handler.get_output())
27+
# response = temp_callback_handler.get_output()['output']
28+
history[-1][1] = _chat_message[-1]
29+
return history
30+
31+
32+
with gr.Blocks() as demo:
33+
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=1500)
34+
35+
with gr.Row():
36+
with gr.Column(scale=1):
37+
txt = gr.Textbox(
38+
show_label=False,
39+
placeholder="Enter text and press enter",
40+
).style(container=False)
41+
txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
42+
bot, chatbot, chatbot
43+
)
44+
45+
if __name__ == "__main__":
46+
load_dotenv()
47+
_spec = OpenAPISpecOpenAIWrapper.from_url('http://petstore.swagger.io/v2/swagger.json',
48+
llm_interface=OpenAIInterface(), index=True)
49+
demo.launch()
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
You are an API expert. Your goal is to assist the user in using an API that he/she is not familiar with.
2+
The user will provide commands which you will use to find out information about an API
3+
Then the user will ask you questions about the API and you will answer them.
4+
5+
Rules:
6+
1. You will only answer questions about the API
7+
2. You will keep your output only to the essential information

func_ai/utils/common.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import inspect
2+
3+
4+
def arg_in_func(func, arg_name):
5+
# Get the signature of the function
6+
signature = inspect.signature(func)
7+
8+
# Get the parameters of the function from the signature
9+
parameters = signature.parameters
10+
11+
# Check if the arg_name is in the parameters
12+
return arg_name in parameters

func_ai/utils/llm_tools.py

+66-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
from pydantic import BaseModel, Field
1717

18-
from func_ai.utils.py_function_parser import type_mapping
18+
from func_ai.utils.common import arg_in_func
19+
from func_ai.utils.py_function_parser import type_mapping, func_to_json
1920

2021
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
2122
logger = logging.getLogger(__name__)
@@ -44,6 +45,13 @@ def get_conversation(self) -> list:
4445
"""
4546
return self.conversation
4647

48+
def get_last_message(self) -> Any:
49+
"""
50+
Returns the last message
51+
:return:
52+
"""
53+
return self.conversation[-1]
54+
4755

4856
class OpenAIConversationStore(ConversationStore):
4957

@@ -129,15 +137,28 @@ def get_conversation(self) -> list[any]:
129137
"""
130138
return self.conversation_store.get_conversation()
131139

132-
def add_conversation_message(self, message: any) -> "LLMInterface":
140+
def add_conversation_message(self, message: any, update_llm: bool = False, **kwargs) -> "LLMInterface":
133141
"""
134142
Adds a message to the conversation
135-
:param message:
143+
:param message: The message to add
144+
:param update_llm: Whether to update the LLM with the new conversation
145+
:param kwargs: Parameters to pass to the API
136146
:return:
137147
"""
138148
self.conversation_store.add_message(message)
149+
if update_llm:
150+
self.update_llm_conversation(**kwargs)
139151
return self
140152

153+
def update_llm_conversation(self, **kwargs) -> "LLMInterface":
154+
"""
155+
Sends the updated conversation to the LLM
156+
157+
:param kwargs: Parameters to pass to the API
158+
:return:
159+
"""
160+
raise NotImplementedError
161+
141162

142163
class OpenAIInterface(LLMInterface):
143164
"""
@@ -158,12 +179,15 @@ def __init__(self, *args, **kwargs):
158179

159180
@retry(stop=stop_after_attempt(3), reraise=True, wait=wait_fixed(1),
160181
retry_error_callback=lambda x: logger.warning(x))
161-
def send(self, prompt: str, **kwargs) -> dict:
182+
def update_llm_conversation(self, **kwargs) -> "OpenAIInterface":
183+
"""
184+
Sends the updated conversation to the LLM
185+
186+
:param kwargs: Parameters to pass to the API
187+
:return:
188+
"""
162189
_functions = kwargs.get("functions", None)
163190
_model = kwargs.get("model", self.model)
164-
# print(type(self._conversation_store))
165-
self.conversation_store.add_message({"role": "user", "content": prompt})
166-
logger.debug(f"Prompt: {prompt}")
167191
try:
168192
if _functions:
169193
response = openai.ChatCompletion.create(
@@ -199,12 +223,20 @@ def send(self, prompt: str, **kwargs) -> dict:
199223
self.update_cost(_model, response)
200224
_response_message = response["choices"][0]["message"]
201225
self.conversation_store.add_message(_response_message)
202-
return _response_message
226+
return self
203227
except Exception as e:
204228
logger.error(f"Error: {e}")
205229
traceback.print_exc()
206230
raise e
207231

232+
def send(self, prompt: str, **kwargs) -> dict[str, any]:
233+
# print(type(self._conversation_store))
234+
self.conversation_store.add_message({"role": "user", "content": prompt})
235+
logger.debug(f"Prompt: {prompt}")
236+
response = self.update_llm_conversation(**kwargs)
237+
logger.debug(f"Response: {response}")
238+
return self.conversation_store.get_last_message()
239+
208240
def load_cost_mapping(self, file_path: str) -> None:
209241
with open(file_path) as f:
210242
self.cost_mapping = json.load(f)
@@ -221,7 +253,6 @@ def update_cost(self, model, api_response) -> None:
221253
self.usage[model]["total_tokens"] += api_response['usage']['total_tokens']
222254

223255

224-
225256
class OpenAISchema(BaseModel):
226257
@classmethod
227258
@property
@@ -320,7 +351,10 @@ def __init__(self, llm_interface: LLMInterface, name: str, description: str, par
320351
assert "required" in parameters, "Required field not present in parameters"
321352
self._parameters = parameters
322353
assert callable(func) or isinstance(func, functools.partial), "Function must be callable"
323-
self.func = functools.partial(func, action=self)
354+
if arg_in_func(func, "action"):
355+
self.func = functools.partial(func, action=self)
356+
else:
357+
self.func = func
324358
self._metadata = kwargs
325359
self._llm_calls = []
326360

@@ -369,6 +403,15 @@ def metadata(self) -> dict[str, any]:
369403
"""
370404
return self._metadata
371405

406+
@property
407+
def metadata_dict(self) -> dict[str, str]:
408+
"""
409+
Returns the metadata of the function as a
410+
411+
:return:
412+
"""
413+
return {k: str(v) for k, v in self._metadata.items()}
414+
372415
@property
373416
def schema(self) -> dict[str, any]:
374417
"""
@@ -457,3 +500,16 @@ def from_prompt(self, prompt: str, **kwargs) -> "OpenAIFunctionWrapper":
457500
"""
458501
self.from_response(self.llm_interface.send(prompt, functions=[self.schema]))
459502
return self
503+
504+
@classmethod
505+
def from_python_function(cls, func: callable, llm_interface: LLMInterface, **kwargs) -> "OpenAIFunctionWrapper":
506+
"""
507+
Returns an instance of the class from Python function
508+
509+
:param func: Python function
510+
:param llm_interface: LLM interface
511+
:param kwargs: arguments to be passed to the function
512+
:return:
513+
"""
514+
_func = func_to_json(func)
515+
return cls(llm_interface=llm_interface, func=func, **_func, **kwargs)

0 commit comments

Comments
 (0)