Skip to content

Commit

Permalink
Merge pull request #175 from stacklok/setup-dashboard
Browse files Browse the repository at this point in the history
Added initial dashboard functionality
  • Loading branch information
aponcedeleonch authored Dec 4, 2024
2 parents 461d77a + 2aedfdf commit 334e18a
Show file tree
Hide file tree
Showing 8 changed files with 602 additions and 6 deletions.
11 changes: 11 additions & 0 deletions sql/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,14 @@ LEFT JOIN outputs o ON p.id = o.prompt_id
LEFT JOIN alerts a ON p.id = a.prompt_id
WHERE p.id = ?
ORDER BY o.timestamp DESC, a.timestamp DESC;


-- name: GetPromptWithOutputs :many
SELECT
p.*,
o.id as output_id,
o.output,
o.timestamp as output_timestamp
FROM prompts p
LEFT JOIN outputs o ON p.id = o.prompt_id
ORDER BY o.timestamp DESC;
30 changes: 30 additions & 0 deletions src/codegate/dashboard/dashboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
from typing import List

import structlog
from fastapi import APIRouter

from codegate.dashboard.post_processing import match_conversations, parse_get_prompt_with_output
from codegate.dashboard.request_models import Conversation
from codegate.db.connection import DbReader

logger = structlog.get_logger("codegate")

dashboard_router = APIRouter(tags=["Dashboard"])
db_reader = DbReader()


@dashboard_router.get("/dashboard/messages")
async def get_messages() -> List[Conversation]:
"""
Get all the messages from the database and return them as a list of conversations.
"""
prompts_outputs = await db_reader.get_prompts_with_output()

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
partial_conversations = [task.result() for task in tasks]

conversations = await match_conversations(partial_conversations)
return conversations
189 changes: 189 additions & 0 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import asyncio
import json
from typing import List, Optional, Tuple

import structlog

from codegate.dashboard.request_models import (
ChatMessage,
Conversation,
PartialConversation,
QuestionAnswer,
)
from codegate.db.queries import GetPromptWithOutputsRow

logger = structlog.get_logger("codegate")


SYSTEM_PROMPTS = [
"Given the following... please reply with a short summary that is 4-12 words in length, "
"you should summarize what the user is asking for OR what the user is trying to accomplish. "
"You should only respond with the summary, no additional text or explanation, "
"you don't need ending punctuation.",
]


async def _is_system_prompt(message: str) -> bool:
"""
Check if the message is a system prompt.
"""
for prompt in SYSTEM_PROMPTS:
if prompt in message or message in prompt:
return True
return False


async def parse_request(request_str: str) -> Optional[str]:
"""
Parse the request string from the pipeline and return the message.
"""
try:
request = json.loads(request_str)
except Exception as e:
logger.exception(f"Error parsing request: {e}")
return None

messages = []
for message in request.get("messages", []):
role = message.get("role")
if not role == "user":
continue
content = message.get("content")

message_str = ""
if isinstance(content, str):
message_str = content
elif isinstance(content, list):
for content_part in content:
if isinstance(content_part, dict) and content_part.get("type") == "text":
message_str = content_part.get("text")

if message_str and not await _is_system_prompt(message_str):
messages.append(message_str)

# We couldn't get anything from the messages, try the prompt
if not messages:
message_prompt = request.get("prompt", "")
if message_prompt and not await _is_system_prompt(message_prompt):
messages.append(message_prompt)

# If still we don't have anything, return empty string
if not messages:
return None

# Only respond with the latest message
return messages[-1]


async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the output string from the pipeline and return the message and chat_id.
"""
try:
output = json.loads(output_str)
except Exception as e:
logger.exception(f"Error parsing request: {e}")
return None, None

output_message = ""
chat_id = None
if isinstance(output, list):
for output_chunk in output:
if not isinstance(output_chunk, dict):
continue
chat_id = chat_id or output_chunk.get("id")
for choice in output_chunk.get("choices", []):
if not isinstance(choice, dict):
continue
delta_dict = choice.get("delta", {})
output_message += delta_dict.get("content", "")
elif isinstance(output, dict):
chat_id = chat_id or output.get("id")
for choice in output.get("choices", []):
if not isinstance(choice, dict):
continue
output_message += choice.get("message", {}).get("content", "")

return output_message, chat_id


async def parse_get_prompt_with_output(
row: GetPromptWithOutputsRow,
) -> Optional[PartialConversation]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
The row contains the raw request and output strings from the pipeline.
"""
async with asyncio.TaskGroup() as tg:
request_task = tg.create_task(parse_request(row.request))
output_task = tg.create_task(parse_output(row.output))

request_msg_str = request_task.result()
output_msg_str, chat_id = output_task.result()

# If we couldn't parse the request or output, return None
if not request_msg_str or not output_msg_str or not chat_id:
return None

request_message = ChatMessage(
message=request_msg_str,
timestamp=row.timestamp,
message_id=row.id,
)
output_message = ChatMessage(
message=output_msg_str,
timestamp=row.output_timestamp,
message_id=row.output_id,
)
question_answer = QuestionAnswer(
question=request_message,
answer=output_message,
)
return PartialConversation(
question_answer=question_answer,
provider=row.provider,
type=row.type,
chat_id=chat_id,
request_timestamp=row.timestamp,
)


async def match_conversations(
partial_conversations: List[Optional[PartialConversation]],
) -> List[Conversation]:
"""
Match partial conversations to form a complete conversation.
"""
convers = {}
for partial_conversation in partial_conversations:
if not partial_conversation:
continue

# Group by chat_id
if partial_conversation.chat_id not in convers:
convers[partial_conversation.chat_id] = []
convers[partial_conversation.chat_id].append(partial_conversation)

# Sort by timestamp
sorted_convers = {
chat_id: sorted(conversations, key=lambda x: x.request_timestamp)
for chat_id, conversations in convers.items()
}
# Create the conversation objects
conversations = []
for chat_id, sorted_convers in sorted_convers.items():
questions_answers = []
for partial_conversation in sorted_convers:
questions_answers.append(partial_conversation.question_answer)
conversations.append(
Conversation(
question_answers=questions_answers,
provider=partial_conversation.provider,
type=partial_conversation.type,
chat_id=chat_id,
conversation_timestamp=sorted_convers[0].request_timestamp,
)
)

return conversations
47 changes: 47 additions & 0 deletions src/codegate/dashboard/request_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import datetime
from typing import List, Optional

from pydantic import BaseModel


class ChatMessage(BaseModel):
"""
Represents a chat message.
"""

message: str
timestamp: datetime.datetime
message_id: str


class QuestionAnswer(BaseModel):
"""
Represents a question and answer pair.
"""

question: ChatMessage
answer: ChatMessage


class PartialConversation(BaseModel):
"""
Represents a partial conversation obtained from a DB row.
"""

question_answer: QuestionAnswer
provider: Optional[str]
type: str
chat_id: str
request_timestamp: datetime.datetime


class Conversation(BaseModel):
"""
Represents a conversation.
"""

question_answers: List[QuestionAnswer]
provider: Optional[str]
type: str
chat_id: str
conversation_timestamp: datetime.datetime
34 changes: 29 additions & 5 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import uuid
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Optional
from typing import AsyncGenerator, AsyncIterator, List, Optional

import structlog
from litellm import ChatCompletionRequest, ModelResponse
Expand All @@ -13,11 +13,12 @@
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.models import Output, Prompt
from codegate.db.queries import AsyncQuerier, GetPromptWithOutputsRow

logger = structlog.get_logger("codegate")


class DbRecorder:
class DbCodeGate:

def __init__(self, sqlite_path: Optional[str] = None):
# Initialize SQLite database engine with proper async URL
Expand All @@ -27,6 +28,9 @@ def __init__(self, sqlite_path: Optional[str] = None):
else:
self._db_path = Path(sqlite_path).absolute()

# Initialize SQLite database engine with proper async URL
current_dir = Path(__file__).parent
self._db_path = (current_dir.parent.parent.parent / "codegate.db").absolute()
logger.debug(f"Initializing DB from path: {self._db_path}")
engine_dict = {
"url": f"sqlite+aiosqlite:///{self._db_path}",
Expand All @@ -35,13 +39,20 @@ def __init__(self, sqlite_path: Optional[str] = None):
}
self._async_db_engine = create_async_engine(**engine_dict)
self._db_engine = create_engine(**engine_dict)
if not self.does_db_exist():
logger.info(f"Database does not exist at {self._db_path}. Creating..")
asyncio.run(self.init_db())

def does_db_exist(self):
return self._db_path.is_file()


class DbRecorder(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

if not self.does_db_exist():
logger.info(f"Database does not exist at {self._db_path}. Creating..")
asyncio.run(self.init_db())

async def init_db(self):
"""Initialize the database with the schema."""
if self.does_db_exist():
Expand Down Expand Up @@ -177,6 +188,19 @@ async def record_output_non_stream(
return await self._record_output(prompt, output_str)


class DbReader(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
conn = await self._async_db_engine.connect()
querier = AsyncQuerier(conn)
prompts = [prompt async for prompt in querier.get_prompt_with_outputs()]
await conn.close()
return prompts


def init_db_sync():
"""DB will be initialized in the constructor in case it doesn't exist."""
db = DbRecorder()
Expand Down
Loading

0 comments on commit 334e18a

Please sign in to comment.