|
| 1 | +import asyncio |
| 2 | +import json |
| 3 | +from typing import List, Optional, Tuple |
| 4 | + |
| 5 | +import structlog |
| 6 | + |
| 7 | +from codegate.dashboard.request_models import ( |
| 8 | + ChatMessage, |
| 9 | + Conversation, |
| 10 | + PartialConversation, |
| 11 | + QuestionAnswer, |
| 12 | +) |
| 13 | +from codegate.db.queries import GetPromptWithOutputsRow |
| 14 | + |
| 15 | +logger = structlog.get_logger("codegate") |
| 16 | + |
| 17 | + |
| 18 | +SYSTEM_PROMPTS = [ |
| 19 | + "Given the following... please reply with a short summary that is 4-12 words in length, " |
| 20 | + "you should summarize what the user is asking for OR what the user is trying to accomplish. " |
| 21 | + "You should only respond with the summary, no additional text or explanation, " |
| 22 | + "you don't need ending punctuation.", |
| 23 | +] |
| 24 | + |
| 25 | + |
| 26 | +async def _is_system_prompt(message: str) -> bool: |
| 27 | + """ |
| 28 | + Check if the message is a system prompt. |
| 29 | + """ |
| 30 | + for prompt in SYSTEM_PROMPTS: |
| 31 | + if prompt in message or message in prompt: |
| 32 | + return True |
| 33 | + return False |
| 34 | + |
| 35 | + |
async def parse_request(request_str: str) -> Optional[str]:
    """
    Parse the JSON-encoded request string from the pipeline and return the
    latest non-system user message, or None if nothing usable was found.
    """
    try:
        request = json.loads(request_str)
    except (json.JSONDecodeError, TypeError) as e:
        logger.exception("Error parsing request: %s", e)
        return None

    # Valid JSON is not necessarily an object; guard before calling .get().
    if not isinstance(request, dict):
        return None

    messages = []
    for message in request.get("messages", []):
        if message.get("role") != "user":
            continue
        content = message.get("content")

        message_str = ""
        if isinstance(content, str):
            message_str = content
        elif isinstance(content, list):
            # Multi-part content: keep the last non-empty text part so a
            # trailing part with text=None cannot clobber an earlier value.
            for content_part in content:
                if isinstance(content_part, dict) and content_part.get("type") == "text":
                    message_str = content_part.get("text") or message_str

        if message_str and not await _is_system_prompt(message_str):
            messages.append(message_str)

    # We couldn't get anything from the messages, try the legacy
    # completions-style "prompt" field instead.
    if not messages:
        message_prompt = request.get("prompt", "")
        if message_prompt and not await _is_system_prompt(message_prompt):
            messages.append(message_prompt)

    # If we still have nothing, signal failure to the caller.
    if not messages:
        return None

    # Only respond with the latest message.
    return messages[-1]
| 76 | + |
| 77 | + |
async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse the JSON-encoded output string from the pipeline.

    Returns a (message, chat_id) tuple; both elements are None when the
    output cannot be parsed at all.
    """
    try:
        output = json.loads(output_str)
    except (json.JSONDecodeError, TypeError) as e:
        # Fixed: this previously logged "Error parsing request".
        logger.exception("Error parsing output: %s", e)
        return None, None

    output_message = ""
    chat_id = None
    if isinstance(output, list):
        # Streaming response: concatenate the delta content of every chunk.
        for output_chunk in output:
            if not isinstance(output_chunk, dict):
                continue
            chat_id = chat_id or output_chunk.get("id")
            for choice in output_chunk.get("choices", []):
                if not isinstance(choice, dict):
                    continue
                # "delta"/"content" may be present but None (e.g. the final
                # chunk of a stream); coalesce to "" to avoid a TypeError.
                output_message += (choice.get("delta") or {}).get("content") or ""
    elif isinstance(output, dict):
        # Non-streaming response: read the full message of each choice.
        chat_id = chat_id or output.get("id")
        for choice in output.get("choices", []):
            if not isinstance(choice, dict):
                continue
            output_message += (choice.get("message") or {}).get("content") or ""

    return output_message, chat_id
| 108 | + |
| 109 | + |
async def parse_get_prompt_with_output(
    row: GetPromptWithOutputsRow,
) -> Optional[PartialConversation]:
    """
    Build a PartialConversation from a get_prompt_with_outputs row.

    The row carries the raw request and output strings produced by the
    pipeline; both are parsed concurrently. Returns None when either side
    (or the chat_id) could not be extracted.
    """
    async with asyncio.TaskGroup() as tg:
        req_task = tg.create_task(parse_request(row.request))
        out_task = tg.create_task(parse_output(row.output))

    question_str = req_task.result()
    answer_str, chat_id = out_task.result()

    # Every piece must be present to form a usable conversation fragment.
    if not (question_str and answer_str and chat_id):
        return None

    question_answer = QuestionAnswer(
        question=ChatMessage(
            message=question_str,
            timestamp=row.timestamp,
            message_id=row.id,
        ),
        answer=ChatMessage(
            message=answer_str,
            timestamp=row.output_timestamp,
            message_id=row.output_id,
        ),
    )
    return PartialConversation(
        question_answer=question_answer,
        provider=row.provider,
        type=row.type,
        chat_id=chat_id,
        request_timestamp=row.timestamp,
    )
| 150 | + |
| 151 | + |
async def match_conversations(
    partial_conversations: List[Optional[PartialConversation]],
) -> List[Conversation]:
    """
    Group partial conversations by chat_id into complete Conversation objects.

    None entries are skipped. Within each conversation the question/answer
    pairs are ordered by request timestamp; the conversation's provider and
    type are taken from its latest entry, and its timestamp from the
    earliest one.
    """
    grouped = {}
    for partial in partial_conversations:
        if not partial:
            continue
        # Group by chat_id.
        grouped.setdefault(partial.chat_id, []).append(partial)

    conversations = []
    for chat_id, parts in grouped.items():
        # Order each conversation's entries chronologically.
        parts.sort(key=lambda p: p.request_timestamp)
        conversations.append(
            Conversation(
                question_answers=[p.question_answer for p in parts],
                provider=parts[-1].provider,
                type=parts[-1].type,
                chat_id=chat_id,
                conversation_timestamp=parts[0].request_timestamp,
            )
        )

    return conversations
0 commit comments