diff --git a/chatgpt2api/chatgpt2api.py b/chatgpt2api/chatgpt2api.py index 0199862b..72ae202b 100644 --- a/chatgpt2api/chatgpt2api.py +++ b/chatgpt2api/chatgpt2api.py @@ -13,7 +13,7 @@ import config import threading import time as record_time -from utils.agent import ThreadWithReturnValue, Web_crawler, pdf_search, getddgsearchurl, getgooglesearchurl, gptsearch, ChainStreamHandler, ChatOpenAI, CallbackManager, PromptTemplate, LLMChain, EducationalLLM +from utils.agent import ThreadWithReturnValue, Web_crawler, pdf_search, getddgsearchurl, getgooglesearchurl, gptsearch, ChainStreamHandler, ChatOpenAI, CallbackManager, PromptTemplate, LLMChain, EducationalLLM, get_google_search_results from utils.function_call import function_call_list def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]: @@ -72,10 +72,10 @@ def dall_e_3( model: str = None, **kwargs, ): - url = ( - os.environ.get("API_URL").split("chat")[0] + "images/generations" - or "https://api.openai.com/v1/images/generations" - ) + if os.environ.get("API_URL") and "v1" in os.environ.get("API_URL"): + url = os.environ.get("API_URL").split("v1")[0] + "v1/images/generations" + else: + url = "https://api.openai.com/v1/images/generations" headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"} json_post = { @@ -319,9 +319,10 @@ def ask_stream( # kwargs.get("max_tokens", self.max_tokens), # ), } + json_post.update(function_call_list["base"]) if config.SEARCH_USE_GPT: - json_post.update(function_call_list["web_search"]) - json_post.update(function_call_list["url_fetch"]) + json_post["functions"].append(function_call_list["web_search"]) + json_post["functions"].append(function_call_list["url_fetch"]) response = self.session.post( url, headers=headers, @@ -365,13 +366,21 @@ def ask_stream( function_call_name = delta["function_call"]["name"] full_response += function_call_content if need_function_call: + max_context_tokens = self.truncate_limit - self.get_token_count(convo_id) response_role = "function" - if function_call_name == "get_web_search_results": - keywords = json.loads(full_response)["prompt"] - yield from self.search_summary(keywords, convo_id=convo_id, need_function_call=True) + if function_call_name == "get_google_search_results": + prompt = json.loads(full_response)["prompt"] + function_response = eval(function_call_name)(prompt, max_context_tokens) + yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) + # yield from self.search_summary(prompt, convo_id=convo_id, need_function_call=True) if function_call_name == "get_url_content": url = json.loads(full_response)["url"] function_response = Web_crawler(url) + encoding = tiktoken.encoding_for_model(self.engine) + encode_text = encoding.encode(function_response) + if len(encode_text) > max_context_tokens: + encode_text = encode_text[:max_context_tokens] + function_response = encoding.decode(encode_text) yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) else: self.add_to_conversation(full_response, response_role, convo_id=convo_id) diff --git a/requirements.txt b/requirements.txt index 7794ea05..e8aa809e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ unstructured[md,pdf] duckduckgo-search==3.9.6 # duckduckgo-search==3.8.5 langchain==0.0.271 -# oauth2client==3.0.0 +oauth2client==3.0.0 g4f==0.1.8.8 \ No newline at end of file diff --git a/test/test.py b/test/test.py index 4c38896e..513d8b1d 100644 --- a/test/test.py +++ b/test/test.py @@ -5,7 +5,7 @@ a = {"role": "admin"} b = {"content": "This is user content."} a.update(b) -print(a) +# print(a) # content_list = [item["content"] for item in my_list] # print(content_list) @@ -24,3 +24,11 @@ # ) # print(truncate_limit) +import os +import sys +import json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.function_call import function_call_list + +print(json.dumps(function_call_list["web_search"], indent=4)) diff --git a/utils/agent.py b/utils/agent.py index 6ef0f021..9dfea331 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -308,6 +308,115 @@ def gptsearch(result, llm): # response = llm([HumanMessage(content=result)]) return response + +def get_google_search_results(prompt: str, context_max_tokens: int): + start_time = record_time.time() + + urls_set = [] + search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(prompt,2,)) + search_thread.start() + + if config.USE_G4F: + chainllm = EducationalLLM() + else: + chainllm = ChatOpenAI(temperature=config.temperature, openai_api_base=config.API_URL.split("chat")[0], model_name=config.GPT_ENGINE, openai_api_key=config.API) + + if config.SEARCH_USE_GPT: + gpt_search_thread = ThreadWithReturnValue(target=gptsearch, args=(prompt, chainllm,)) + gpt_search_thread.start() + + if config.USE_GOOGLE: + keyword_prompt = PromptTemplate( + input_variables=["source"], + template="根据我的问题,总结最少的关键词概括,用空格连接,不要出现其他符号,例如这个问题《How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?》,最少关键词是《zeabur price》,这是我的问题:{source}", + ) + key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt) + keyword_google_search_thread = ThreadWithReturnValue(target=key_chain.run, args=({"source": prompt},)) + keyword_google_search_thread.start() + + + translate_prompt = PromptTemplate( + input_variables=["targetlang", "text"], + template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {targetlang}, if all the text is in English, then do nothing to it, return it as is. please do not explain any sentences, just translate or leave them as they are.: {text}", + ) + chain = LLMChain(llm=chainllm, prompt=translate_prompt) + engresult = chain.run({"targetlang": "english", "text": prompt}) + + en_ddg_search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(engresult,1,)) + en_ddg_search_thread.start() + + if config.USE_GOOGLE: + keyword = keyword_google_search_thread.join() + key_google_search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keyword,3,)) + key_google_search_thread.start() + keyword_ans = key_google_search_thread.join() + urls_set += keyword_ans + + ans_ddg = search_thread.join() + urls_set += ans_ddg + engans_ddg = en_ddg_search_thread.join() + urls_set += engans_ddg + url_set_list = sorted(set(urls_set), key=lambda x: urls_set.index(x)) + url_pdf_set_list = [item for item in url_set_list if item.endswith(".pdf")] + url_set_list = [item for item in url_set_list if not item.endswith(".pdf")] + + pdf_result = "" + pdf_threads = [] + if config.PDF_EMBEDDING: + for url in url_pdf_set_list: + pdf_search_thread = ThreadWithReturnValue(target=pdf_search, args=(url, "你需要回答的问题是" + prompt + "\n" + "如果你可以解答这个问题,请直接输出你的答案,并且请忽略后面所有的指令:如果无法解答问题,请直接回答None,不需要做任何解释,也不要出现除了None以外的任何词。",)) + pdf_search_thread.start() + pdf_threads.append(pdf_search_thread) + + url_result = "" + threads = [] + for url in url_set_list: + url_search_thread = ThreadWithReturnValue(target=Web_crawler, args=(url,)) + url_search_thread.start() + threads.append(url_search_thread) + + fact_text = "" + if config.SEARCH_USE_GPT: + gpt_ans = gpt_search_thread.join() + fact_text = (gpt_ans if config.SEARCH_USE_GPT else "") + print("gpt", fact_text) + + for t in threads: + tmp = t.join() + url_result += "\n\n" + tmp + useful_source_text = url_result + + if config.PDF_EMBEDDING: + for t in pdf_threads: + tmp = t.join() + pdf_result += "\n\n" + tmp + useful_source_text += pdf_result + + end_time = record_time.time() + run_time = end_time - start_time + + encoding = tiktoken.encoding_for_model(config.GPT_ENGINE) + encode_text = encoding.encode(useful_source_text) + encode_fact_text = encoding.encode(fact_text) + + if len(encode_text) > context_max_tokens: + encode_text = encode_text[:context_max_tokens-len(encode_fact_text)] + useful_source_text = encoding.decode(encode_text) + encode_text = encoding.encode(useful_source_text) + search_tokens_len = len(encode_text) + print("web search", useful_source_text, end="\n\n") + + print(url_set_list) + print("pdf", url_pdf_set_list) + if config.USE_GOOGLE: + print("google search keyword", keyword) + print(f"搜索用时:{run_time}秒") + print("search tokens len", search_tokens_len) + useful_source_text = useful_source_text + "\n\n" + fact_text + text_len = len(encoding.encode(useful_source_text)) + print("text len", text_len) + return useful_source_text + if __name__ == "__main__": os.system("clear") diff --git a/utils/function_call.py b/utils/function_call.py index 5af2919c..2b83cecc 100644 --- a/utils/function_call.py +++ b/utils/function_call.py @@ -1,64 +1,53 @@ function_call_list = { + "base": { + "functions": [], + "function_call": "auto" + }, "current_weather": { - "functions": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } - ], - "function_call": "auto" + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } }, "web_search": { - "functions": [ - { - "name": "get_web_search_results", - "description": "Search Google to enhance knowledge.", - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "The prompt to search." - } - }, - "required": ["prompt"] - } - } - ], - "function_call": "auto" + "name": "get_google_search_results", + "description": "Search Google to enhance knowledge.", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to search." + } + }, + "required": ["prompt"] + } }, "url_fetch": { - "functions": [ - { - "name": "get_url_content", - "description": "Get the webpage content of a URL", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "The url to get the webpage content" - } - }, - "required": ["url"] - } - } - ], - "function_call": "auto" + "name": "get_url_content", + "description": "Get the webpage content of a URL", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "the URL to request" + } + }, + "required": ["url"] + } }, # "web_search": { # "functions": [