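"""FastAPI service that streams responses from an Azure OpenAI-backed LangChain chat agent."""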
import asyncio
import io
import json
import os
import re
import time
from datetime import date
from difflib import IS_LINE_JUNK
from typing import Any
import openai
import openai as aierror  # alias used for the error-type mapping in the chat handler below
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain.agents import AgentType, initialize_agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import FinalStreamingStdOutCallbackHandler
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import LLMResult, OutputParserException
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import AzureChatOpenAI
from pydantic import BaseModel
from agent.chatprompt import RESPONSE_PROMPT
from agent.util import CustomError, UtilFunctions
# load environment variables from the local .env file
load_dotenv(".env")
# instantiate the FastAPI app and shared utilities
app = FastAPI()
util = UtilFunctions()
# system-wide settings
environment = os.environ.get("ENVIRONMENT", "")
memory_len = int(os.environ.get("MEMORY_LENGTH", "3"))  # defaults to 3 turns of history
temperature = float(os.environ.get("TEMPERATURE", "0.2"))
# LLM configuration
verbose = False
model = os.getenv("AZURE_OPENAI_MODEL")
deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID")
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_KEY")
# set up the LLM through Azure OpenAI
llm = AzureChatOpenAI(
    azure_deployment=deployment_name,
    model=model,
    temperature=temperature,
    streaming=True,
    max_tokens=100,
    callback_manager=BaseCallbackManager([StreamingStdOutCallbackHandler()]),
    openai_api_key=openai.api_key,
    verbose=verbose,
    openai_api_version=openai.api_version,
    top_p=1,
    frequency_penalty=0.0,
    presence_penalty=0.0,
)
# windowed conversation memory; k controls how many turns are kept in context
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=memory_len,
    return_messages=True,
    output_key="output",
)
# conversational ReAct agent; stops after 3 iterations and generates a final answer if cut off
agent = initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=[],
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate",
    memory=memory,
    return_intermediate_steps=False,
    handle_parsing_errors=True,
)
# Handles the output from asynchronous token generation
class AsyncCallbackHandler(AsyncIteratorCallbackHandler):
    content: str = ""  # accumulates tokens as they arrive
    final_answer: bool = False  # flag set once the final answer section begins

    def __init__(self) -> None:
        super().__init__()

    # Handles each new token generated by the LLM
    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.content += token
        # If the final answer has been reached, forward incoming tokens
        if self.final_answer:
            # Filter out JSON delimiters from the agent's structured output
            if token not in ['"', "}"]:
                self.queue.put_nowait(token)  # queue the token for asynchronous handling
        # Detect when the final answer starts
        elif "Final Answer" in self.content:
            self.final_answer = True  # final answer is now in progress
            self.content = ""  # clear accumulated scratchpad text

    # Handles the end of the LLM's response
    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.final_answer:
            self.content = ""  # clear the content after processing the final answer
            self.final_answer = False  # reset the flag for the next call
            self.done.set()  # signal that processing is complete
        else:
            self.content = ""  # reset content if the final answer wasn't reached
# Run the query through the agent, streaming tokens via the supplied handler
async def run_call(query: str, stream_it: AsyncCallbackHandler):
    # swap in this request's streaming handler before invoking the agent
    agent.agent.llm_chain.llm.callbacks = [stream_it]
    await agent.acall(inputs={"input": query})  # asynchronously call the agent with the query input
# Base model for a bare query request
class Query(BaseModel):
    Request: str
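# Note: Query is not referenced by the route below; the chat endpoint takes a Message instead.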
# Create and drive the generation of a streamed response from the LLM
async def create_gen(query: str, stream_it: AsyncCallbackHandler, callback=None):
    task = asyncio.create_task(run_call(query, stream_it))  # run the agent concurrently
    response: str = ""
    # Asynchronously iterate over the generated tokens
    async for token in stream_it.aiter():
        if not token:
            break  # stop if an empty token is produced
        yield token
        response += token  # accumulate the tokens into the full response
    await task
    # invoke the optional callback with the completed response
    if callback:
        async def handle_response_async(response):
            try:
                # Await the callback and return its result
                return await callback(response)
            except Exception as e:
                print(f"Error during callback execution: {e}")
        await handle_response_async(response)
# Example callback invoked with the completed response (can be sync or async)
async def handle_response(response):
    print("")
# CORS middleware that additionally matches origins against compiled regular expressions
class RegexCORSMiddleware(CORSMiddleware):
    def __init__(
        self,
        app: FastAPI,
        allow_origins: list = ("*",),
        allow_origin_regex: list = None,
        allow_methods: list = ("GET",),
        allow_headers: list = (),
        expose_headers: list = (),
        allow_credentials: bool = False,
        max_age: int = 600,
    ):
        super().__init__(
            app,
            allow_origins=allow_origins,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            expose_headers=expose_headers,
            allow_credentials=allow_credentials,
            max_age=max_age,
        )
        self.allow_origin_regex = allow_origin_regex or []

    # Starlette's is_allowed_origin is synchronous, so the override must be too;
    # note that re.Pattern.match is not awaitable
    def is_allowed_origin(self, origin: str) -> bool:
        if "*" in self.allow_origins:
            return True
        for regex in self.allow_origin_regex:
            if regex.match(origin):
                return True
        return False
# Define allowed origins using regular expressions (note the escaped dots)
allowed_origins = [
    re.compile(r"https?://localhost(:\d+)?"),  # localhost with an optional port
    re.compile(r"https?://127\.0\.0\.1(:\d+)?"),  # loopback IP with an optional port
]
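# e.g. "http://localhost:3000" and "https://127.0.0.1:8080" both match the patterns above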
# Enable CORS; the compiled patterns must go to allow_origin_regex, which is what
# is_allowed_origin consults (allow_origins is reserved for literal origins or "*")
app.add_middleware(
    RegexCORSMiddleware,
    allow_origins=[],
    allow_origin_regex=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Schema for incoming chat messages
class Message(BaseModel):
    UserId: int
    Request: str
    Environment: str
    FirstName: str
    LastName: str
    GUID: str
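# Illustrative payload (field values are hypothetical):
#   {"UserId": 1, "Request": "Critique my abstract.", "Environment": "local",
#    "FirstName": "Ada", "LastName": "Lovelace", "GUID": "123e4567"}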
# Main chat endpoint: manages and processes the user input
@app.post("/api/dls-chat/v1")
async def chat(message: Message, request: Request, background_tasks: BackgroundTasks):
    try:
        user_id = int(message.UserId)
    except Exception as err:
        # Log the error type and message for debugging
        print(f"An unexpected error occurred: {type(err).__name__}, {str(err)}")
    question = message.Request.strip()
    print(f"Question: {question}")
    standard_question = ""
    query = "Question: " + question
    # Input validation; each failed check raises a CustomError with a user-facing message
    try:
        start_time = time.time()
        if not util.check_relevance(question):
            error = (
                "Hello! It looks like you've asked something I am not trained on: "
                f"'{question}'\n"
                "Please ask something related to editing or providing critique. "
                "I'll do my best to assist you."
            )
            raise CustomError(error, error_type="CustomError")
        if not util.check_valid_string(question):
            error = (
                "Hello! It looks like you've entered an incomplete prompt: "
                f"'{question}'\n"
                'Replace the keyword "{keyword}" with a relevant term. '
                "I'll do my best to assist you."
            )
            raise CustomError(error, error_type="CustomError")
        if len(question) <= 1:
            error = (
                "Hello! It looks like you've entered nothing. "
                "If you have a specific question or need help with something, \n"
                "please feel free to elaborate, and I'll do my best to assist you."
            )
            raise CustomError(error, error_type="CustomError")
        if question.lower() in ("hi", "hello"):
            error = "Hello! How can I assist you today?"
            raise CustomError(error, error_type="CustomError")
        if len(question) <= 3 or util.is_junk_string(question):
            error = (
                "Hello! It looks like you've entered just the letters "
                f"'{question}'\n"
                "If you have a specific question or need help with something, "
                "please feel free to elaborate, and I'll do my best to assist you."
            )
            raise CustomError(error, error_type="CustomError")
        if IS_LINE_JUNK(question):
            error = (
                "Hello! It looks like you've not entered anything. "
                "Can you provide more information or clarify your question? \n"
                "I'd be happy to help with whatever you need."
            )
            raise CustomError(error, error_type="CustomError")
        # standardize the question against the chat history
        chat_history = memory.load_memory_variables({})["chat_history"]
        contextualize_q_system_prompt = """
        Given a chat history and the latest user question, which might reference \
        context in the chat history (for example, an individual mentioned earlier), \
        formulate a standalone question that can be understood without the chat history. \
        Do NOT answer the question; just reformulate it if needed, otherwise return it as is. \
        Chat history: {chat_history}
        """
        # contextualization chain: prompt -> LLM -> plain string
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )
        contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
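        # Illustration (hypothetical): if the history discusses a particular essay, the
        # follow-up "Can you shorten it?" should be rewritten to something like
        # "Can you shorten the essay discussed above?" before being answered.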
        # rewrite the question only when there is prior history
        if len(chat_history):
            standard_question = await contextualize_q_chain.ainvoke(
                {
                    "chat_history": chat_history,
                    "question": question,
                }
            )
        # fall back to the raw question if contextualization produced nothing
        if not standard_question.strip():
            standard_question = question
        else:
            chat_history.extend(
                [
                    HumanMessage(content=question),
                    AIMessage(content=standard_question),
                ]
            )
        error_code = 200
        res_message = "Success"
        year = date.today().year
        baseURL = os.environ.get("BaseURL")
        prompt_response = ChatPromptTemplate.from_template(RESPONSE_PROMPT)
        prompt_question = prompt_response.format(
            question=standard_question,
            baseURL=baseURL,
        )
        stream_it = AsyncCallbackHandler()
        gen = create_gen(prompt_question, stream_it, callback=handle_response)
        return StreamingResponse(gen, media_type="text/event-stream")
    # A single except block maps known error types to user-facing messages
    except Exception as err:
        print("Error:", type(err))
        error_message = {
            aierror.ConflictError: (
                "Issue connecting to our services, please try again later and report this error if possible."
            ),
            aierror.NotFoundError: (
                "Requested resource does not exist, please try again later and report this error if possible."
            ),
            aierror.APIStatusError: (
                "Something went wrong processing your request, please try again later."
            ),
            aierror.AuthenticationError: (
                "Your API key or token was invalid, expired, or revoked. Please try again later and report this error if possible."
            ),
            aierror.InternalServerError: (
                "Something went wrong processing your request, please try again later."
            ),
            aierror.PermissionDeniedError: (
                "No access to the requested resource, please try again later."
            ),
            aierror.UnprocessableEntityError: (
                "Something went wrong processing your request, please try again later."
            ),
            aierror.BadRequestError: (
                "vChat likely used too many words while processing your request. Try limiting how many results you ask for by adding something like 'only give me the top 3 results'."
            ),
            aierror.RateLimitError: (
                "vChat has exceeded the current quota, please try again after some time."
            ),
            aierror.APITimeoutError: (
                "vChat took too long to process your request, try again in a little while."
            ),
            OutputParserException: (
                "vChat ran into an issue parsing some text, try modifying your question."
            ),
            aierror.APIConnectionError: (
                "Something went wrong with the OpenAI API, please try again later and report this error if possible."
            ),
            aierror.APIError: (
                "Something went wrong with the OpenAI API, please try again later and report this error if possible."
            ),
            aierror.OpenAIError: (
                "Something went wrong with the OpenAI API, please try again later and report this error if possible."
            ),
        }
        res_message, error_code = str(err), 400
        # fall back to the raw error text (e.g. a CustomError's message) when the type is unknown
        ai_response = error_message.get(type(err), res_message)
        if not ai_response:
            if environment.lower() == "local":
                ai_response = (
                    "An unknown error occurred, please report this or try again later. "
                    "Error reason: " + str(err)
                )
            else:
                ai_response = "An unknown error occurred while connecting to the database, please report this and try again!"
        if environment.lower() == "local":
            ai_response = ai_response + "<br> <b> Transaction ID - " + message.GUID + " </b>"
        return StreamingResponse(io.StringIO(ai_response), status_code=error_code, media_type="text/event-stream")
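# To run this service locally (a typical invocation, assuming uvicorn is installed
# and this file is saved as archive.py):
#   uvicorn archive:app --reload
# Example request (field values are hypothetical):
#   curl -N -X POST http://localhost:8000/api/dls-chat/v1 \
#        -H "Content-Type: application/json" \
#        -d '{"UserId": 1, "Request": "Critique my abstract.", "Environment": "local",
#             "FirstName": "Ada", "LastName": "Lovelace", "GUID": "123e4567"}'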