Skip to content

Commit

Permalink
Refactor token usage tracking code. Show in UI.
Browse files Browse the repository at this point in the history
  • Loading branch information
paulovcmedeiros committed Nov 11, 2023
1 parent 334bdd1 commit 3437bba
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 224 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
license = "MIT"
name = "pyrobbot"
readme = "README.md"
version = "0.1.1"
version = "0.1.2"

[build-system]
build-backend = "poetry.core.masonry.api"
Expand Down
30 changes: 25 additions & 5 deletions pyrobbot/app/app_page_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import uuid
from abc import ABC, abstractmethod
from json.decoder import JSONDecodeError
from loguru import logger

import streamlit as st
from loguru import logger
from PIL import Image

from pyrobbot import GeneralConstants
Expand Down Expand Up @@ -152,15 +152,28 @@ def render_chat_history(self):
with st.chat_message(role, avatar=self.avatars.get(role)):
st.markdown(message["content"])

def render(self):
def render_cost_estimate_page(self):
"""Render the estimated costs information in the chat."""
general_df = self.chat_obj.general_token_usage_db.get_usage_balance_dataframe()
chat_df = self.chat_obj.token_usage_db.get_usage_balance_dataframe()
dfs = {"All Recorded Chats": general_df, "Current Chat": chat_df}

st.header(dfs["Current Chat"].attrs["description"], divider="rainbow")
with st.container():
for category, df in dfs.items():
st.subheader(f"**{category}**")
st.dataframe(df)
st.write()
st.caption(df.attrs["disclaimer"])

def _render_chatbot_page(self):
"""Render a chatbot page.
Adapted from:
<https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps>
"""
st.title(self.title)
st.divider()
st.header(self.title, divider="rainbow")

if self.chat_history:
self.render_chat_history()
Expand Down Expand Up @@ -230,4 +243,11 @@ def render(self):

self.title = title
self.sidebar_title = title
st.title(title)
st.header(title, divider="rainbow")

def render(self):
"""Render the app's chatbot or costs page, depending on user choice."""
if st.session_state.get("toggle_show_costs"):
self.render_cost_estimate_page()
else:
self._render_chatbot_page()
7 changes: 6 additions & 1 deletion pyrobbot/app/multipage.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class MultipageChatbotApp(AbstractMultipageApp):

def init_openai_client(self):
"""Initializes the OpenAI client with the API key provided in the Streamlit UI."""
# Initialize the OpenAI API client
placeholher = (
"OPENAI_API_KEY detected"
if GeneralConstants.OPENAI_API_KEY
Expand Down Expand Up @@ -190,6 +189,12 @@ def render(self, **kwargs):
tab1, tab2 = st.tabs(["Chats", "Settings for Current Chat"])
self.sidebar_tabs = {"chats": tab1, "settings": tab2}
with tab1:
# Add button to show the costs table
st.toggle(
key="toggle_show_costs",
label=":moneybag:",
help="Show estimated token usage and associated costs",
)
# Add button to create a new chat
new_chat_button = st.button(label=":heavy_plus_sign: New Chat")

Expand Down
79 changes: 38 additions & 41 deletions pyrobbot/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
import shutil
import uuid
from collections import defaultdict
from pathlib import Path

from loguru import logger
Expand All @@ -12,7 +11,7 @@
from .chat_configs import ChatOptions
from .chat_context import EmbeddingBasedChatContext, FullHistoryChatContext
from .openai_utils import make_api_chat_completion_call
from .tokens import TokenUsageDatabase, get_n_tokens_from_msgs
from .tokens import TokenUsageDatabase


class Chat:
Expand Down Expand Up @@ -43,9 +42,6 @@ def __init__(self, configs: ChatOptions = None):

self.cache_dir.mkdir(parents=True, exist_ok=True)

self.token_usage = defaultdict(lambda: {"input": 0, "output": 0})
self.token_usage_db = TokenUsageDatabase(fpath=self.token_usage_db_path)

if self.context_model == "full-history":
self.context_handler = FullHistoryChatContext(parent_chat=self)
elif self.context_model == "text-embedding-ada-002":
Expand Down Expand Up @@ -76,6 +72,21 @@ def clear_cache(self):
"""Remove the cache directory."""
shutil.rmtree(self.cache_dir, ignore_errors=True)

@property
def token_usage_db_path(self):
"""Return the path to the chat's token usage database."""
return self.cache_dir / "chat_token_usage.db"

@property
def token_usage_db(self):
"""Return the chat's token usage database."""
return TokenUsageDatabase(fpath=self.token_usage_db_path)

@property
def general_token_usage_db(self):
"""Return the general token usage database for all chats."""
return TokenUsageDatabase(fpath=self.general_token_usage_db_path)

@property
def configs_file(self):
"""File to store the chat's configs."""
Expand Down Expand Up @@ -132,14 +143,6 @@ def base_directive(self):
return {"role": "system", "name": self.system_name, "content": msg_content}

def __del__(self):
# Store token usage to database
for model in [self.model, self.context_model]:
self.token_usage_db.insert_data(
model=model,
n_input_tokens=self.token_usage[model]["input"],
n_output_tokens=self.token_usage[model]["output"],
)

cache_empty = self.cache_dir.exists() and not next(
self.cache_dir.iterdir(), False
)
Expand Down Expand Up @@ -220,48 +223,25 @@ def respond_system_prompt(self, prompt: str, **kwargs):
def yield_response_from_msg(self, prompt_msg: dict, add_to_history: bool = True):
"""Yield response from a prompt message."""
# Get appropriate context for prompt from the context handler
prompt_context_request = self.context_handler.get_context(msg=prompt_msg)
context = prompt_context_request["context_messages"]

# Update token_usage with tokens used in context handler for prompt
self.token_usage[self.context_model]["input"] += sum(
prompt_context_request["tokens_usage"].values()
)

contextualised_prompt = [self.base_directive, *context, prompt_msg]
# Update token_usage with tokens used in chat input
self.token_usage[self.model]["input"] += get_n_tokens_from_msgs(
messages=contextualised_prompt, model=self.model
)
context = self.context_handler.get_context(msg=prompt_msg)

# Make API request and yield response chunks
full_reply_content = ""
for chunk in make_api_chat_completion_call(
conversation=contextualised_prompt, chat_obj=self
conversation=[self.base_directive, *context, prompt_msg], chat_obj=self
):
full_reply_content += chunk
yield chunk

# Update token_usage ith tokens used in chat output
reply_as_msg = {"role": "assistant", "content": full_reply_content}
self.token_usage[self.model]["output"] += get_n_tokens_from_msgs(
messages=[reply_as_msg], model=self.model
)

if add_to_history:
# Put current chat exchange in context handler's history
history_entry_reg_tokens_usage = self.context_handler.add_to_history(
self.context_handler.add_to_history(
msg_list=[
prompt_msg,
{"role": "assistant", "content": full_reply_content},
]
)

# Update token_usage with tokens used in context handler for reply
self.token_usage[self.context_model]["output"] += sum(
history_entry_reg_tokens_usage.values()
)

def start(self):
"""Start the chat."""
# ruff: noqa: T201
Expand All @@ -280,9 +260,26 @@ def start(self):
print("", end="\r")
logger.info("Exiting chat.")

def report_token_usage(self, current_chat: bool = True):
def report_token_usage(self, report_current_chat=True, report_general: bool = False):
"""Report token usage and associated costs."""
self.token_usage_db.print_usage_costs(self.token_usage, current_chat=current_chat)
dfs = {}
if report_general:
dfs[
"All Recorded Chats"
] = self.general_token_usage_db.get_usage_balance_dataframe()
if report_current_chat:
dfs["Current Chat"] = self.token_usage_db.get_usage_balance_dataframe()

if dfs:
for category, df in dfs.items():
header = f"{df.attrs['description']}: {category}"
table_separator = "=" * (len(header) + 4)
print(table_separator)
print(f" {header} ")
print(table_separator)
print(df)
print()
print(df.attrs["disclaimer"])

def _respond_prompt(self, prompt: str, role: str, **kwargs):
prompt_as_msg = {"role": role.lower().strip(), "content": prompt.strip()}
Expand Down
2 changes: 1 addition & 1 deletion pyrobbot/chat_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ChatOptions(OpenAiApiCallOptions):
),
description="Initial instructions for the AI",
)
token_usage_db_path: Optional[Path] = Field(
general_token_usage_db_path: Optional[Path] = Field(
default=GeneralConstants.TOKEN_USAGE_DATABASE,
description="Path to the token usage database",
)
Expand Down
76 changes: 38 additions & 38 deletions pyrobbot/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,81 +38,81 @@ def context_file_path(self):

def add_to_history(self, msg_list: list[dict]):
"""Add message exchange to history."""
embedding_request = self.request_embedding(msg_list=msg_list)
self.database.insert_message_exchange(
chat_model=self.parent_chat.model,
message_exchange=msg_list,
embedding=embedding_request["embedding"],
embedding=self.request_embedding(msg_list=msg_list),
)
return embedding_request["tokens_usage"]

def load_history(self) -> list[dict]:
"""Load the chat history."""
messages_df = self.database.get_messages_dataframe()
msg_exchanges = messages_df["message_exchange"].apply(ast.literal_eval).tolist()
return list(itertools.chain.from_iterable(msg_exchanges))

def get_context(self, msg: dict):
"""Return messages to serve as context for `msg` when requesting a completion."""
return _make_list_of_context_msgs(
history=self.select_relevant_history(msg=msg),
system_name=self.parent_chat.system_name,
)

@abstractmethod
def request_embedding(self, msg_list: list[dict]):
"""Request embedding from OpenAI API."""

@abstractmethod
def get_context(self, msg: dict):
"""Return context messages."""
def select_relevant_history(self, msg: dict):
"""Select chat history msgs to use as context for `msg`."""


class FullHistoryChatContext(ChatContext):
"""Context class using full chat history."""

def __init__(self, *args, **kwargs):
"""Initialise instance. Args and kwargs are passed to the parent class' `init`."""
super().__init__(*args, **kwargs)
self._placeholder_tokens_usage = {"input": 0, "output": 0}

# Implement abstract methods
def request_embedding(self, msg_list: list[dict]): # noqa: ARG002
"""Return a placeholder embedding request."""
return {"embedding": None, "tokens_usage": self._placeholder_tokens_usage}
"""Return a placeholder embedding."""
return

def get_context(self, msg: dict): # noqa: ARG002
"""Return context messages."""
context_msgs = _make_list_of_context_msgs(
history=self.load_history(), system_name=self.parent_chat.system_name
)
return {
"context_messages": context_msgs,
"tokens_usage": self._placeholder_tokens_usage,
}
def select_relevant_history(self, msg: dict): # noqa: ARG002
"""Select chat history msgs to use as context for `msg`."""
return self.load_history()


class EmbeddingBasedChatContext(ChatContext):
"""Chat context using embedding models."""

def _request_embedding_for_text(self, text: str):
return request_embedding_from_openai(text=text, model=self.embedding_model)
def request_embedding_for_text(self, text: str):
"""Request embedding for `text` from OpenAI according to used embedding model."""
embedding_request = request_embedding_from_openai(
text=text, model=self.embedding_model
)

# Update parent chat's token usage db with tokens used in embedding request
for db in [
self.parent_chat.general_token_usage_db,
self.parent_chat.token_usage_db,
]:
for comm_type, n_tokens in embedding_request["tokens_usage"].items():
input_or_output_kwargs = {f"n_{comm_type}_tokens": n_tokens}
db.insert_data(model=self.embedding_model, **input_or_output_kwargs)

return embedding_request["embedding"]

# Implement abstract methods
def request_embedding(self, msg_list: list[dict]):
"""Request embedding from OpenAI API."""
"""Convert `msg_list` into a paragraph and get embedding from OpenAI API call."""
text = "\n".join(
[f"{msg['role'].strip()}: {msg['content'].strip()}" for msg in msg_list]
)
return self._request_embedding_for_text(text=text)
return self.request_embedding_for_text(text=text)

def get_context(self, msg: dict):
"""Return context messages."""
embedding_request = self._request_embedding_for_text(text=msg["content"])
selected_history = _select_relevant_history(
def select_relevant_history(self, msg: dict):
"""Select chat history msgs to use as context for `msg`."""
return _select_relevant_history(
history_df=self.database.get_messages_dataframe(),
embedding=embedding_request["embedding"],
)
context_messages = _make_list_of_context_msgs(
history=selected_history, system_name=self.parent_chat.system_name
embedding=self.request_embedding_for_text(text=msg["content"]),
)
return {
"context_messages": context_messages,
"tokens_usage": embedding_request["tokens_usage"],
}


@retry_api_call()
Expand All @@ -138,7 +138,7 @@ def _make_list_of_context_msgs(history: list[dict], system_name: str):

def _select_relevant_history(
history_df: pd.DataFrame,
embedding: list[float],
embedding: np.ndarray,
max_n_prompt_reply_pairs: int = 5,
max_n_tailing_prompt_reply_pairs: int = 2,
):
Expand Down
5 changes: 3 additions & 2 deletions pyrobbot/command_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
def accounting(args):
"""Show the accumulated costs of the chat and exit."""
chat = Chat.from_cli_args(cli_args=args)
# Prevent chat from creating entry in the cache directory
chat.private_mode = True
chat.report_token_usage(current_chat=False)
chat.report_token_usage(report_general=True, report_current_chat=False)


def run_on_terminal(args):
"""Run the chat on the terminal."""
chat = Chat.from_cli_args(cli_args=args)
chat.start()
if args.report_accounting_when_done:
chat.report_token_usage(current_chat=True)
chat.report_token_usage(report_general=True)


def run_on_ui(args):
Expand Down
Loading

0 comments on commit 3437bba

Please sign in to comment.