diff --git a/.flakeheaven.toml b/.flakeheaven.toml new file mode 100644 index 0000000..47c8298 --- /dev/null +++ b/.flakeheaven.toml @@ -0,0 +1,24 @@ +[tool.flakeheaven] + exclude = [".*/", "tmp/", "*/tmp/", "*.ipynb"] + format = "colored" + # Show line of source code in output, with syntax highlighting + show_source = true + style = "google" + +# list of plugins and rules for them +[tool.flakeheaven.plugins] + # Deactivate all rules for all plugins by default + "*" = ["-*"] + # Activate only those plugins not covered by ruff + pydoclint = [ + "+*", + "-DOC105", + "-DOC106", + "-DOC107", + "-DOC109", + "-DOC110", + "-DOC203", + "-DOC301", + "-DOC403", + "-DOC404", + ] diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml new file mode 100644 index 0000000..aefdd41 --- /dev/null +++ b/.github/workflows/linting.yaml @@ -0,0 +1,68 @@ +#.github/workflows/linting.yaml +name: Linting Checks + +on: + pull_request: + branches: + - main + - develop + paths: + - '**.py' + - '.github/workflows/linting.yaml' + push: + branches: + - '**' # Every branch + paths: + - '**.py' + - '.github/workflows/linting.yaml' + +jobs: + linting: + if: github.repository_owner == 'paulovcmedeiros' + name: Run Linters + runs-on: ubuntu-latest + steps: + #---------------------------------------------- + # check-out repo and set-up python + #---------------------------------------------- + - name: Check out repository + uses: actions/checkout@v3 + - name: Set up python + id: setup-python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + #---------------------------------------------- + # --- configure poetry & install project ---- + #---------------------------------------------- + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Install poethepoet + run: poetry self add 'poethepoet[poetry_plugin]' + + - name: Load cached venv (if cache exists) + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: ${{ github.job }}-venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }} + + - name: Install dependencies (if venv cache is not found) + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction --no-root --only main,linting + + - name: Install the project itself + run: poetry install --no-interaction --only-root + + #---------------------------------------------- + # Run the linting checks + #---------------------------------------------- + - name: Run linters + run: | + poetry devtools lint + diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..f59c8e4 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,72 @@ +#.github/workflows/tests.yaml +name: Unit Tests + +on: + pull_request: + branches: + - main + - develop + push: + branches: + - '**' # Every branch + +jobs: + tests: + if: github.repository_owner == 'paulovcmedeiros' + strategy: + fail-fast: true + matrix: + os: [ "ubuntu-latest" ] + env: [ "pytest" ] + python-version: [ "3.9" ] + + name: "${{ matrix.os }}, python=${{ matrix.python-version }}" + runs-on: ${{ matrix.os }} + + container: + image: python:${{ matrix.python-version }}-bullseye + env: + COVERAGE_FILE: ".coverage.${{ matrix.env }}.${{ matrix.python-version }}" + + steps: + #---------------------------------------------- + # check-out repo + 
#---------------------------------------------- + - name: Check out repository + uses: actions/checkout@v3 + + #---------------------------------------------- + # --- configure poetry & install project ---- + #---------------------------------------------- + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Load cached venv (if cache exists) + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: ${{ github.job }}-venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }} + + - name: Install dependencies (if venv cache is not found) + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction --no-root --only main,test + + - name: Install the project itself + run: poetry install --no-interaction --only-root + + #---------------------------------------------- + # run test suite and report coverage + #---------------------------------------------- + - name: Run tests + run: | + poetry run pytest + + - name: Upload test coverage report to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./.coverage.xml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b8c0e73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# Vim
+*.swp
+
+# Temporary files and directories
+tmp/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..dabe323
--- /dev/null
+++ b/README.md
@@ -0,0 +1,75 @@
+[![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/paulovcmedeiros/pyRobBot)
+
+
+[![Contributors Welcome](https://img.shields.io/badge/Contributors-welcome-.svg)](https://github.com/paulovcmedeiros/pyRobBot/pulls)
+[![Linting](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/linting.yaml/badge.svg)](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/linting.yaml)
+[![Tests](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/tests.yaml/badge.svg)](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/tests.yaml)
+[![codecov](https://codecov.io/gh/paulovcmedeiros/pyRobBot/graph/badge.svg?token=XI8G1WH9O6)](https://codecov.io/gh/paulovcmedeiros/pyRobBot)
+
+# pyRobBot
+
+A simple chatbot that uses the OpenAI API to get responses from [GPT LLMs](https://platform.openai.com/docs/models). Written in Python, with a Web UI made with [Streamlit](https://streamlit.io). It can also be used directly from the terminal.
+
+See also the [online documentation](https://paulovcmedeiros.github.io/pyRobBot-docs).
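+
+For a quick, illustrative taste of the configurability described below: every field
+of the `ChatOptions` model (defined in `pyrobbot/chat_configs.py`) is exposed as a
+command line flag, so something along these lines should work. Treat the exact
+values as an example, and run `rob terminal -h` for the authoritative list:
+
+```shell
+# Hypothetical invocation: flag names are auto-generated from ChatOptions fields
+rob terminal --model gpt-4 --temperature 0.7
+```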
+
+## Features
+- [x] Web UI
+  - Add/remove conversations dynamically
+- [x] Fully configurable
+  - Support for multiple GPT LLMs
+  - Control over the parameters passed to the OpenAI API, with (hopefully) sensible defaults
+  - Ability to modify the chat parameters in the same conversation
+  - Each conversation has its own parameters
+- [x] Autosave and retrieve chat history
+- [x] Chat context handling using [embeddings](https://platform.openai.com/docs/guides/embeddings)
+- [x] Keep track of estimated token usage and associated API call costs
+- [x] Terminal UI
+
+
+## System Requirements
+- Python >= 3.9
+- A valid [OpenAI API key](https://platform.openai.com/account/api-keys)
+  - Set in the Web UI or through the environment variable `OPENAI_API_KEY`
+
+## Installation
+### Using pip
+```shell
+pip install pyrobbot
+```
+
+### From source
+```shell
+pip install git+https://github.com/paulovcmedeiros/pyRobBot.git
+```
+
+## Basic Usage
+Upon successful installation, you should be able to run
+```shell
+rob [opts] SUBCOMMAND [subcommand_opts]
+```
+where `[opts]` and `[subcommand_opts]` denote optional command line arguments
+that apply, respectively, to `rob` in general and to `SUBCOMMAND`
+specifically.
+
+**Please run `rob -h` for information** about the supported subcommands
+and general `rob` options. For info about specific subcommands and the
+options that apply to them only, **please run `rob SUBCOMMAND -h`** (note
+that the `-h` goes after the subcommand in this case).
+
+### Using the Web UI
+```shell
+rob
+```
+
+### Running on the Terminal
+```shell
+rob .
+```
+
+## Disclaimers
+This project's main purpose is to serve as a learning exercise for me (the author) and as a tool for experimenting with the OpenAI API and GPT LLMs. It does not aim to be the best or most robust OpenAI-powered chatbot out there.
+
+Having said this, this project *does* aim to have a friendly user interface and to be easy to use and configure. So, please feel free to open an issue or submit a pull request if you find a bug or have a suggestion.
+
+Last but not least: this project is **not** affiliated with OpenAI in any way.
+
+
diff --git a/poetry.toml b/poetry.toml
new file mode 100644
index 0000000..c880064
--- /dev/null
+++ b/poetry.toml
@@ -0,0 +1,7 @@
+[virtualenvs]
+  create = true
+  in-project = true
+  prefer-active-python = true
+
+[virtualenvs.options]
+  system-site-packages = false
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b6fbe46
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,158 @@
+[tool.poetry]
+  authors = ["Paulo V C Medeiros "]
+  description = "A simple UI & terminal chatbot that uses the OpenAI API."
+  license = "MIT"
+  name = "pyrobbot"
+  readme = "README.md"
+  version = "0.1.0"
+
+[build-system]
+  build-backend = "poetry.core.masonry.api"
+  requires = ["poetry-core"]
+
+[tool.poetry.scripts]
+  rob = "pyrobbot.__main__:main"
+
+[tool.poetry.dependencies]
+  # Python version
+  python = ">=3.9,<3.9.7 || >3.9.7,<3.13"
+  # Deps that should have been openai deps
+  matplotlib = "^3.8.0"
+  plotly = "^5.18.0"
+  scikit-learn = "^1.3.2"
+  scipy = "^1.11.3"
+  # Other dependencies
+  loguru = "^0.7.2"
+  numpy = "^1.26.1"
+  openai = "^0.28.1"
+  pandas = "^2.1.2"
+  pillow = "^10.1.0"
+  pydantic = "^2.4.2"
+  streamlit = "^1.28.0"
+  tiktoken = "^0.5.1"
+
+[tool.poetry.group.dev.dependencies]
+  ipython = "^8.16.1"
+
+[tool.poetry.group.linting.dependencies]
+  black = "^23.10.1"
+  flakeheaven = "^3.3.0"
+  isort = "^5.12.0"
+  pydoclint = "^0.3.8"
+  ruff = "^0.1.3"
+
+[tool.poetry.group.test.dependencies]
+  pytest = "^7.4.3"
+  pytest-cov = "^4.1.0"
+  pytest-mock = "^3.12.0"
+  pytest-order = "^1.1.0"
+  python-lorem = "^1.3.0.post1"
+
+  ##################
+  # Linter configs #
+  ##################
+
+[tool.black]
+  line-length = 90
+
+[tool.flakeheaven]
+  base = ".flakeheaven.toml"
+
+[tool.isort]
+  line_length = 90
+  profile = "black"
+
+[tool.ruff]
+  # C901: Function is too complex. Ignored for now; to be removed later.
+  ignore = ["C901", "D105", "EXE001", "RET504", "RUF012"]
+  line-length = 90
+  select = [
+    "A",
+    "ARG",
+    "B",
+    "BLE",
+    "C4",
+    "C90",
+    "D",
+    "E",
+    "ERA",
+    "EXE",
+    "F",
+    "G",
+    "I",
+    "N",
+    "PD",
+    "PERF",
+    "PIE",
+    "PL",
+    "PT",
+    "Q",
+    "RET",
+    "RSE",
+    "RUF",
+    "S",
+    "SIM",
+    "SLF",
+    "T20",
+    "W",
+  ]
+
+[tool.ruff.per-file-ignores]
+  # S101: Use of `assert` detected
+  "tests/**/*.py" = [
+    "D100",
+    "D101",
+    "D102",
+    "D103",
+    "D104",
+    "D105",
+    "D106",
+    "D107",
+    "E501",
+    "S101",
+    "SLF001",
+  ]
+
+[tool.ruff.pydocstyle]
+  convention = "google"
+
+  ##################
+  # pytest configs #
+  ##################
+
+[tool.pytest.ini_options]
+  addopts = "-v --failed-first --cov-report=term-missing --cov-report=term:skip-covered --cov-report=xml:.coverage.xml --cov=./"
+  log_cli_level = "INFO"
+  testpaths = ["tests/smoke", "tests/unit"]
+
+  ####################################
+  # Leave configs for `poe` separate #
+  ####################################
+
+[tool.poe]
+  poetry_command = "devtools"
+
+[tool.poe.tasks]
+  _black = "black ."
+  _isort = "isort ."
+  _ruff = "ruff check ."
+  # Test-related tasks
+  pytest = "pytest"
+  # Tasks to be run as pre-push checks
+  pre-push-checks = ["lint", "pytest"]
+
+[tool.poe.tasks._flake8]
+  cmd = "flakeheaven lint ."
+  env = {FLAKEHEAVEN_CACHE_TIMEOUT = "0"}
+
+[tool.poe.tasks.lint]
+  args = [{name = "fix", type = "boolean", default = false}]
+  control = {expr = "fix"}
+
+[[tool.poe.tasks.lint.switch]]
+  case = "True"
+  sequence = ["_isort", "_black", "_ruff --fix", "_flake8"]
+
+[[tool.poe.tasks.lint.switch]]
+  case = "False"
+  sequence = ["_isort --check-only", "_black --check --diff", "_ruff", "_flake8"]
diff --git a/pyrobbot/__init__.py b/pyrobbot/__init__.py
new file mode 100644
index 0000000..acf2c68
--- /dev/null
+++ b/pyrobbot/__init__.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+"""Unofficial OpenAI API UI and CLI tool."""
+import os
+import tempfile
+import uuid
+from importlib.metadata import version
+from pathlib import Path
+
+import openai
+
+
+class GeneralConstants:
+    """General constants for the package."""
+
+    # Main package info
+    RUN_ID = uuid.uuid4().hex
+    PACKAGE_NAME = __name__
+    VERSION = version(__name__)
+
+    # Main package directories
+    PACKAGE_DIRECTORY = Path(__file__).parent
+    PACKAGE_CACHE_DIRECTORY = Path.home() / ".cache" / PACKAGE_NAME
+    _PACKAGE_TMPDIR = tempfile.TemporaryDirectory()
+    PACKAGE_TMPDIR = Path(_PACKAGE_TMPDIR.name)
+    CHAT_CACHE_DIR = PACKAGE_CACHE_DIRECTORY / "chats"
+
+    # Constants related to the app
+    APP_NAME = "pyRobBot"
+    APP_DIR = PACKAGE_DIRECTORY / "app"
+    APP_PATH = APP_DIR / "app.py"
+    PARSED_ARGS_FILE = PACKAGE_TMPDIR / f"parsed_args_{RUN_ID}.pkl"
+
+    # Constants related to using the OpenAI API
+    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+    TOKEN_USAGE_DATABASE = PACKAGE_CACHE_DIRECTORY / "token_usage.db"
+
+    # Initialise the package's directories
+    PACKAGE_TMPDIR.mkdir(parents=True, exist_ok=True)
+    PACKAGE_CACHE_DIRECTORY.mkdir(parents=True, exist_ok=True)
+    CHAT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+# Initialize the OpenAI API client
+openai.api_key = GeneralConstants.OPENAI_API_KEY
diff --git a/pyrobbot/__main__.py b/pyrobbot/__main__.py
new file mode 100644
index 0000000..9f012e3
--- /dev/null
+++ b/pyrobbot/__main__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+"""Program's entry point."""
+from .argparse_wrapper import get_parsed_args
+
+
+def main(argv=None):
+    """Program's main routine."""
+    args = get_parsed_args(argv=argv)
+    args.run_command(args=args)
diff --git a/pyrobbot/app/.streamlit/config.toml b/pyrobbot/app/.streamlit/config.toml
new file mode 100644
index 0000000..612800d
--- /dev/null
+++ b/pyrobbot/app/.streamlit/config.toml
@@ -0,0 +1,18 @@
+# Streamlit configs.
+# See .
+[browser]
+  gatherUsageStats = false
+
+[runner]
+  fastReruns = true
+
+[server]
+  runOnSave = true
+
+[client]
+  showErrorDetails = true
+
+[theme]
+  base = "light"
+  # Colors
+  primaryColor = "#2BB5E8"
diff --git a/pyrobbot/app/__init__.py b/pyrobbot/app/__init__.py
new file mode 100644
index 0000000..5246618
--- /dev/null
+++ b/pyrobbot/app/__init__.py
@@ -0,0 +1 @@
+"""UI for the package."""
diff --git a/pyrobbot/app/app.py b/pyrobbot/app/app.py
new file mode 100644
index 0000000..20a9161
--- /dev/null
+++ b/pyrobbot/app/app.py
@@ -0,0 +1,14 @@
+"""Entrypoint for the package's UI."""
+from pyrobbot import GeneralConstants
+from pyrobbot.app.multipage import MultipageChatbotApp
+
+
+def run_app():
+    """Create and run an instance of the package's app."""
+    MultipageChatbotApp(
+        page_title=GeneralConstants.APP_NAME, page_icon=":speech_balloon:"
+    ).render()
+
+
+if __name__ == "__main__":
+    run_app()
diff --git a/pyrobbot/app/app_page_templates.py b/pyrobbot/app/app_page_templates.py
new file mode 100644
index 0000000..86095ac
--- /dev/null
+++ b/pyrobbot/app/app_page_templates.py
@@ -0,0 +1,227 @@
+"""Utilities for creating pages in a streamlit app."""
+import sys
+import uuid
+from abc import ABC, abstractmethod
+
+import streamlit as st
+from PIL import Image
+
+from pyrobbot import GeneralConstants
+from pyrobbot.chat import Chat
+from pyrobbot.chat_configs import ChatOptions
+from pyrobbot.openai_utils import CannotConnectToApiError
+
+_AVATAR_FILES_DIR = GeneralConstants.APP_DIR / "data"
+_ASSISTANT_AVATAR_FILE_PATH = _AVATAR_FILES_DIR / "assistant_avatar.png"
+_USER_AVATAR_FILE_PATH = _AVATAR_FILES_DIR / "user_avatar.png"
+_ASSISTANT_AVATAR_IMAGE = Image.open(_ASSISTANT_AVATAR_FILE_PATH)
+_USER_AVATAR_IMAGE = Image.open(_USER_AVATAR_FILE_PATH)
+
+
+# Sentinel object for when a chat is recovered from cache
+_RecoveredChat = object()
+
+
+class AppPage(ABC):
+    """Abstract base class for a page within a streamlit application."""
+
+    def __init__(self, sidebar_title: str = "", page_title: str = ""):
+        """Initializes a new instance of the AppPage class.
+
+        Args:
+            sidebar_title (str, optional): The title to be displayed in the sidebar.
+                Defaults to an empty string.
+            page_title (str, optional): The title to be displayed on the page.
+                Defaults to an empty string.
+ """ + self.page_id = str(uuid.uuid4()) + self.page_number = st.session_state.get("n_created_pages", 0) + 1 + + chat_number_for_title = f"Chat #{self.page_number}" + if page_title is _RecoveredChat: + self.fallback_page_title = f"{chat_number_for_title.strip('#')} (Recovered)" + page_title = None + else: + self.fallback_page_title = chat_number_for_title + if page_title: + self.title = page_title + + self._fallback_sidebar_title = page_title if page_title else chat_number_for_title + if sidebar_title: + self.sidebar_title = sidebar_title + + @property + def state(self): + """Return the state of the page, for persistence of data.""" + if self.page_id not in st.session_state: + st.session_state[self.page_id] = {} + return st.session_state[self.page_id] + + @property + def sidebar_title(self): + """Get the title of the page in the sidebar.""" + return self.state.get("sidebar_title", self._fallback_sidebar_title) + + @sidebar_title.setter + def sidebar_title(self, value: str): + """Set the sidebar title for the page.""" + self.state["sidebar_title"] = value + + @property + def title(self): + """Get the title of the page.""" + return self.state.get("page_title", self.fallback_page_title) + + @title.setter + def title(self, value: str): + """Set the title of the page.""" + self.state["page_title"] = value + + @abstractmethod + def render(self): + """Create the page.""" + + +class ChatBotPage(AppPage): + """Implement a chatbot page in a streamlit application, inheriting from AppPage.""" + + def __init__( + self, chat_obj: Chat = None, sidebar_title: str = "", page_title: str = "" + ): + """Initialize new instance of the ChatBotPage class with an optional Chat object. + + Args: + chat_obj (Chat): The chat object. Defaults to None. + sidebar_title (str): The sidebar title for the chatbot page. + Defaults to an empty string. + page_title (str): The title for the chatbot page. + Defaults to an empty string. + """ + super().__init__(sidebar_title=sidebar_title, page_title=page_title) + + if chat_obj: + self.chat_obj = chat_obj + + self.avatars = {"assistant": _ASSISTANT_AVATAR_IMAGE, "user": _USER_AVATAR_IMAGE} + + @property + def chat_configs(self) -> ChatOptions: + """Return the configs used for the page's chat object.""" + if "chat_configs" not in self.state: + chat_options_file_path = sys.argv[-1] + self.state["chat_configs"] = ChatOptions.from_file(chat_options_file_path) + return self.state["chat_configs"] + + @chat_configs.setter + def chat_configs(self, value: ChatOptions): + self.state["chat_configs"] = ChatOptions.model_validate(value) + if "chat_obj" in self.state: + del self.state["chat_obj"] + + @property + def chat_obj(self) -> Chat: + """Return the chat object responsible for the queries in this page.""" + if "chat_obj" not in self.state: + self.chat_obj = Chat(self.chat_configs) + return self.state["chat_obj"] + + @chat_obj.setter + def chat_obj(self, new_chat_obj: Chat): + self.state["chat_obj"] = new_chat_obj + self.state["chat_configs"] = new_chat_obj.configs + + @property + def chat_history(self) -> list[dict[str, str]]: + """Return the chat history of the page.""" + if "messages" not in self.state: + self.state["messages"] = [] + return self.state["messages"] + + def render_chat_history(self): + """Render the chat history of the page. 
+
+        Do not include system messages.
+        """
+        for message in self.chat_history:
+            role = message["role"]
+            if role == "system":
+                continue
+            with st.chat_message(role, avatar=self.avatars.get(role)):
+                st.markdown(message["content"])
+
+    def render(self):
+        """Render a chatbot page.
+
+        Adapted from:
+
+
+        """
+        st.title(self.title)
+        st.divider()
+
+        if self.chat_history:
+            self.render_chat_history()
+        else:
+            with st.chat_message("assistant", avatar=self.avatars["assistant"]):
+                st.markdown(self.chat_obj.initial_greeting)
+            self.chat_history.append(
+                {
+                    "role": "assistant",
+                    "name": self.chat_obj.assistant_name,
+                    "content": self.chat_obj.initial_greeting,
+                }
+            )
+
+        # Accept user input
+        placeholder = (
+            f"Send a message to {self.chat_obj.assistant_name} ({self.chat_obj.model})"
+        )
+        if prompt := st.chat_input(
+            placeholder=placeholder,
+            on_submit=lambda: self.state.update({"chat_started": True}),
+        ):
+            # Display user message in chat message container
+            with st.chat_message("user", avatar=self.avatars["user"]):
+                st.markdown(prompt)
+            self.chat_history.append(
+                {"role": "user", "name": self.chat_obj.username, "content": prompt}
+            )
+
+            # Display (stream) assistant response in chat message container
+            with st.chat_message(
+                "assistant", avatar=self.avatars["assistant"]
+            ), st.empty():
+                st.markdown("▌")
+                full_response = ""
+                try:
+                    for chunk in self.chat_obj.respond_user_prompt(prompt):
+                        full_response += chunk
+                        st.markdown(full_response + "▌")
+                except CannotConnectToApiError:
+                    full_response = self.chat_obj.api_connection_error_msg
+                finally:
+                    st.markdown(full_response)
+
+            self.chat_history.append(
+                {
+                    "role": "assistant",
+                    "name": self.chat_obj.assistant_name,
+                    "content": full_response,
+                }
+            )
+
+            # Reset the title according to the conversation's initial contents
+            min_history_len_for_summary = 3
+            if (
+                "page_title" not in self.state
+                and len(self.chat_history) > min_history_len_for_summary
+            ):
+                with st.spinner("Working out conversation topic..."):
+                    prompt = "Summarize the messages in max 4 words.\n"
+                    title = "".join(
+                        self.chat_obj.respond_system_prompt(prompt, add_to_history=False)
+                    )
+                    self.chat_obj.metadata["page_title"] = title
+                    self.chat_obj.metadata["sidebar_title"] = title
+                    self.chat_obj.save_cache()
+
+                    self.title = title
+                    self.sidebar_title = title
+                    st.title(title)
diff --git a/pyrobbot/app/data/assistant_avatar.png b/pyrobbot/app/data/assistant_avatar.png
new file mode 100644
index 0000000..9c6ee50
Binary files /dev/null and b/pyrobbot/app/data/assistant_avatar.png differ
diff --git a/pyrobbot/app/data/user_avatar.png b/pyrobbot/app/data/user_avatar.png
new file mode 100644
index 0000000..7edab9b
Binary files /dev/null and b/pyrobbot/app/data/user_avatar.png differ
diff --git a/pyrobbot/app/multipage.py b/pyrobbot/app/multipage.py
new file mode 100644
index 0000000..fd0ed02
--- /dev/null
+++ b/pyrobbot/app/multipage.py
@@ -0,0 +1,369 @@
+"""Code for the creation of streamlit apps with dynamically created pages."""
+import contextlib
+from abc import ABC, abstractmethod
+
+import openai
+import streamlit as st
+from pydantic import ValidationError
+
+from pyrobbot import GeneralConstants
+from pyrobbot.app.app_page_templates import AppPage, ChatBotPage, _RecoveredChat
+from pyrobbot.chat import Chat
+from pyrobbot.chat_configs import ChatOptions
+
+
+class AbstractMultipageApp(ABC):
+    """Framework for creating streamlit multipage apps.
+
+    Adapted from:
+    .
+ + """ + + def __init__(self, **kwargs) -> None: + """Initialise streamlit page configs.""" + st.set_page_config(**kwargs) + + @property + def n_created_pages(self): + """Return the number of pages created by the app, including deleted ones.""" + return st.session_state.get("n_created_pages", 0) + + @n_created_pages.setter + def n_created_pages(self, value): + st.session_state["n_created_pages"] = value + + @property + def pages(self) -> dict[AppPage]: + """Return the pages of the app.""" + if "available_pages" not in st.session_state: + st.session_state["available_pages"] = {} + return st.session_state["available_pages"] + + def add_page(self, page: AppPage, selected: bool = True): + """Add a page to the app.""" + self.pages[page.page_id] = page + self.n_created_pages += 1 + if selected: + self.register_selected_page(page) + + def remove_page(self, page: AppPage): + """Remove a page from the app.""" + self.pages[page.page_id].chat_obj.private_mode = True + self.pages[page.page_id].chat_obj.clear_cache() + del self.pages[page.page_id] + try: + self.register_selected_page(next(iter(self.pages.values()))) + except StopIteration: + self.add_page() + + def register_selected_page(self, page: AppPage): + """Register a page as selected.""" + st.session_state["selected_page"] = page + + @property + def selected_page(self) -> ChatBotPage: + """Return the selected page.""" + if "selected_page" not in st.session_state: + return next(iter(self.pages.values())) + return st.session_state["selected_page"] + + @abstractmethod + def handle_ui_page_selection(self, **kwargs): + """Control page selection in the UI sidebar.""" + + def render(self, **kwargs): + """Render the multipage app with focus on the selected page.""" + self.handle_ui_page_selection(**kwargs) + self.selected_page.render() + st.session_state["last_rendered_page"] = self.selected_page.page_id + + +class MultipageChatbotApp(AbstractMultipageApp): + """A Streamlit multipage app specifically for chatbot interactions. + + Inherits from AbstractMultipageApp and adds chatbot-specific functionalities. + + """ + + def init_openai_client(self): + """Initializes the OpenAI client with the API key provided in the Streamlit UI.""" + # Initialize the OpenAI API client + placeholher = ( + "OPENAI_API_KEY detected" + if GeneralConstants.OPENAI_API_KEY + else "You need this to use the chat" + ) + openai_api_key = st.text_input( + label="OpenAI API Key (required)", + placeholder=placeholher, + key="openai_api_key", + type="password", + help="[OpenAI API auth key](https://platform.openai.com/account/api-keys)", + ) + openai.api_key = ( + openai_api_key if openai_api_key else GeneralConstants.OPENAI_API_KEY + ) + if not openai.api_key: + st.write(":red[You need to provide a key to use the chat]") + + def add_page(self, page: ChatBotPage = None, selected: bool = True, **kwargs): + """Adds a new ChatBotPage to the app. + + If no page is specified, a new instance of ChatBotPage is created and added. + + Args: + page: The ChatBotPage to be added. If None, a new page is created. + selected: Whether the added page should be selected immediately. + **kwargs: Additional keyword arguments for ChatBotPage creation. + + Returns: + The result of the superclass's add_page method. 
+ + """ + if page is None: + page = ChatBotPage(**kwargs) + return super().add_page(page=page, selected=selected) + + def get_widget_previous_value(self, widget_key, default=None): + """Get the previous value of a widget, if any.""" + if "widget_previous_value" not in self.selected_page.state: + self.selected_page.state["widget_previous_value"] = {} + return self.selected_page.state["widget_previous_value"].get(widget_key, default) + + def save_widget_previous_values(self, element_key): + """Save a widget's 'previous value`, to be read by `get_widget_previous_value`.""" + if "widget_previous_value" not in self.selected_page.state: + self.selected_page.state["widget_previous_value"] = {} + self.selected_page.state["widget_previous_value"][ + element_key + ] = st.session_state.get(element_key) + + def get_saved_chat_cache_dir_paths(self): + """Get the filepaths of saved chat contexts, sorted by last modified.""" + return sorted( + ( + directory + for directory in GeneralConstants.CHAT_CACHE_DIR.glob("chat_*/") + if next(directory.iterdir(), False) + ), + key=lambda fpath: fpath.stat().st_mtime, + reverse=True, + ) + + def handle_ui_page_selection(self): + """Control page selection and removal in the UI sidebar.""" + _set_button_style() + self._build_sidebar_tabs() + + with self.sidebar_tabs["settings"]: + caption = f"\u2699\uFE0F Settings for Chat #{self.selected_page.page_number}" + if self.selected_page.title != self.selected_page.fallback_page_title: + caption += f": {self.selected_page.title}" + st.caption(caption) + current_chat_configs = self.selected_page.chat_obj.configs + + # Present the user with the model and instructions fields first + field_names = ["model", "ai_instructions", "context_model"] + field_names += list(ChatOptions.model_fields) + field_names = list(dict.fromkeys(field_names)) + model_fields = {k: ChatOptions.model_fields[k] for k in field_names} + + updates_to_chat_configs = self._handle_chat_configs_value_selection( + current_chat_configs, model_fields + ) + + if updates_to_chat_configs: + new_chat_configs = current_chat_configs.model_dump() + new_chat_configs.update(updates_to_chat_configs) + new_chat = Chat.from_dict(new_chat_configs) + self.selected_page.chat_obj = new_chat + new_chat.save_cache() + + def render(self, **kwargs): + """Renders the multipage chatbot app in the UI according to the selected page.""" + with st.sidebar: + st.title(GeneralConstants.APP_NAME) + self.init_openai_client() + # Create a sidebar with tabs for chats and settings + tab1, tab2 = st.tabs(["Chats", "Settings for Current Chat"]) + self.sidebar_tabs = {"chats": tab1, "settings": tab2} + with tab1: + # Add button to create a new chat + new_chat_button = st.button(label=":heavy_plus_sign: New Chat") + + # Reopen chats from cache (if any) + if not st.session_state.get("saved_chats_reloaded", False): + st.session_state["saved_chats_reloaded"] = True + for cache_dir_path in self.get_saved_chat_cache_dir_paths(): + try: + chat = Chat.from_cache(cache_dir=cache_dir_path) + except ValidationError: + st.warning( + f"Failed to load cached chat {cache_dir_path}: " + + "Non-supported configs.", + icon="⚠️", + ) + continue + + new_page = ChatBotPage( + chat_obj=chat, + page_title=chat.metadata.get("page_title", _RecoveredChat), + sidebar_title=chat.metadata.get("sidebar_title"), + ) + new_page.state["messages"] = chat.load_history() + self.add_page(page=new_page) + self.register_selected_page(next(iter(self.pages.values()), None)) + + # Create a new chat upon request or if there is none yet + if 
+                if new_chat_button or not self.pages:
+                    self.add_page()
+
+        return super().render(**kwargs)
+
+    def _build_sidebar_tabs(self):
+        with self.sidebar_tabs["chats"]:
+            for page in self.pages.values():
+                col1, col2 = st.columns([0.9, 0.1])
+                with col1:
+                    st.button(
+                        label=page.sidebar_title,
+                        key=f"select_{page.page_id}",
+                        on_click=self.register_selected_page,
+                        kwargs={"page": page},
+                        use_container_width=True,
+                        disabled=page.page_id == self.selected_page.page_id,
+                    )
+                with col2:
+                    st.button(
+                        ":wastebasket:",
+                        key=f"delete_{page.page_id}",
+                        type="primary",
+                        use_container_width=True,
+                        on_click=self.remove_page,
+                        kwargs={"page": page},
+                        help="Delete this chat.",
+                    )
+
+    def _handle_chat_configs_value_selection(self, current_chat_configs, model_fields):
+        updates_to_chat_configs = {}
+        for field_name, field in model_fields.items():
+            title = field_name.replace("_", " ").title()
+            choices = ChatOptions.get_allowed_values(field=field_name)
+            description = ChatOptions.get_description(field=field_name)
+            field_type = ChatOptions.get_type(field=field_name)
+
+            # Check if the field is frozen and disable corresponding UI element if so
+            chat_started = self.selected_page.state.get("chat_started", False)
+            extra_info = field.json_schema_extra
+            if extra_info is None:
+                extra_info = {}
+            disable_ui_element = extra_info.get("frozen", False) and (
+                chat_started
+                or any(msg["role"] == "user" for msg in self.selected_page.chat_history)
+            )
+
+            # Keep track of selected values so that selectbox doesn't reset
+            current_config_value = getattr(current_chat_configs, field_name)
+            element_key = f"{field_name}-pg-{self.selected_page.page_id}-ui-element"
+            widget_previous_value = self.get_widget_previous_value(
+                element_key, default=current_config_value
+            )
+            if choices:
+                new_field_value = st.selectbox(
+                    title,
+                    key=element_key,
+                    options=choices,
+                    index=choices.index(widget_previous_value),
+                    help=description,
+                    disabled=disable_ui_element,
+                    on_change=self.save_widget_previous_values,
+                    args=[element_key],
+                )
+            elif field_type == str:
+                new_field_value = st.text_input(
+                    title,
+                    key=element_key,
+                    value=widget_previous_value,
+                    help=description,
+                    disabled=disable_ui_element,
+                    on_change=self.save_widget_previous_values,
+                    args=[element_key],
+                )
+            elif field_type in [int, float]:
+                step = 1 if field_type == int else 0.01
+                bounds = [None, None]
+                for item in field.metadata:
+                    with contextlib.suppress(AttributeError):
+                        bounds[0] = item.gt + step
+                    with contextlib.suppress(AttributeError):
+                        bounds[0] = item.ge
+                    with contextlib.suppress(AttributeError):
+                        bounds[1] = item.lt - step
+                    with contextlib.suppress(AttributeError):
+                        bounds[1] = item.le
+
+                new_field_value = st.number_input(
+                    title,
+                    key=element_key,
+                    value=widget_previous_value,
+                    placeholder="OpenAI Default",
+                    min_value=bounds[0],
+                    max_value=bounds[1],
+                    step=step,
+                    help=description,
+                    disabled=disable_ui_element,
+                    on_change=self.save_widget_previous_values,
+                    args=[element_key],
+                )
+            elif field_type in (list, tuple):
+                prev_value = (
+                    widget_previous_value
+                    if isinstance(widget_previous_value, str)
+                    else "\n".join(widget_previous_value)
+                )
+                new_field_value = st.text_area(
+                    title,
+                    value=prev_value.strip(),
+                    key=element_key,
+                    help=description,
+                    disabled=disable_ui_element,
+                    on_change=self.save_widget_previous_values,
+                    args=[element_key],
+                )
+            else:
+                continue
+
+            if new_field_value != current_config_value:
+                if field_type in (list, tuple):
+                    new_field_value = tuple(new_field_value.strip().split("\n"))
+                updates_to_chat_configs[field_name] = new_field_value
+
+        return updates_to_chat_configs
+
+
+def _set_button_style():
+    """CSS styling for the buttons in the app."""
+    st.markdown(
+        """
+
+        """,
+        unsafe_allow_html=True,
+    )
diff --git a/pyrobbot/argparse_wrapper.py b/pyrobbot/argparse_wrapper.py
new file mode 100644
index 0000000..d029068
--- /dev/null
+++ b/pyrobbot/argparse_wrapper.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+"""Wrappers for argparse functionality."""
+import argparse
+import sys
+
+from . import GeneralConstants
+from .chat_configs import ChatOptions
+from .command_definitions import accounting, run_on_terminal, run_on_ui
+
+
+def get_parsed_args(argv=None, default_command="ui"):
+    """Get parsed command line arguments.
+
+    Args:
+        argv (list): A list of passed command line args.
+        default_command (str, optional): The default command to run.
+
+    Returns:
+        argparse.Namespace: Parsed command line arguments.
+
+    """
+    if argv is None:
+        argv = sys.argv[1:]
+    if not argv:
+        argv = [default_command]
+
+    chat_options_parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False
+    )
+    argparse2pydantic = {
+        "type": ChatOptions.get_type,
+        "default": ChatOptions.get_default,
+        "choices": ChatOptions.get_allowed_values,
+        "help": ChatOptions.get_description,
+    }
+    for field_name, field in ChatOptions.model_fields.items():
+        args_opts = {
+            key: argparse2pydantic[key](field_name)
+            for key in argparse2pydantic
+            if argparse2pydantic[key](field_name) is not None
+        }
+        args_opts["required"] = field.is_required()
+        if "help" in args_opts:
+            args_opts["help"] = f"{args_opts['help']} (default: %(default)s)"
+        if "default" in args_opts and isinstance(args_opts["default"], (list, tuple)):
+            args_opts.pop("type", None)
+            args_opts["nargs"] = "*"
+
+        chat_options_parser.add_argument(f"--{field_name.replace('_', '-')}", **args_opts)
+
+    main_parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    main_parser.add_argument(
+        "--version",
+        "-v",
+        action="version",
+        version=f"{GeneralConstants.PACKAGE_NAME} v" + GeneralConstants.VERSION,
+    )
+
+    # Configure the main parser to handle the commands
+    subparsers = main_parser.add_subparsers(
+        title="commands",
+        dest="command",
+        required=True,
+        description=(
+            "Valid commands (note that commands also accept their "
+            + "own arguments, in particular [-h]):"
+        ),
+        help="command description",
+    )
+
+    parser_ui = subparsers.add_parser(
+        "ui",
+        aliases=["app"],
+        parents=[chat_options_parser],
+        help="Run the chat UI on the browser.",
+    )
+    parser_ui.set_defaults(run_command=run_on_ui)
+
+    parser_terminal = subparsers.add_parser(
+        "terminal",
+        aliases=["."],
+        parents=[chat_options_parser],
+        help="Run the chat on the terminal.",
+    )
+    parser_terminal.add_argument(
+        "--report-accounting-when-done",
+        action="store_true",
+        help="Report estimated costs when done with the chat.",
+    )
+    parser_terminal.set_defaults(run_command=run_on_terminal)
+
+    parser_accounting = subparsers.add_parser(
+        "accounting",
+        aliases=["acc"],
+        help="Show the estimated number of used tokens and associated costs, and exit.",
+    )
+    parser_accounting.set_defaults(run_command=accounting)
+
+    return main_parser.parse_args(argv)
diff --git a/pyrobbot/chat.py b/pyrobbot/chat.py
new file mode 100644
index 0000000..15acad2
--- /dev/null
+++ b/pyrobbot/chat.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""Implementation of the Chat class."""
+import json
+import shutil
+import uuid
+from collections import defaultdict +from pathlib import Path + +from loguru import logger + +from . import GeneralConstants +from .chat_configs import ChatOptions +from .chat_context import EmbeddingBasedChatContext, FullHistoryChatContext +from .openai_utils import make_api_chat_completion_call +from .tokens import TokenUsageDatabase, get_n_tokens_from_msgs + + +class Chat: + """Manages conversations with an AI chat model. + + This class encapsulates the chat behavior, including handling the chat context, + managing cache directories, and interfacing with the OpenAI API for generating chat + responses. + """ + + def __init__(self, configs: ChatOptions = None): + """Initializes a chat instance. + + Args: + configs (ChatOptions, optional): The configurations for this chat session. + + Raises: + NotImplementedError: If the context model specified in configs is unknown. + """ + self.id = uuid.uuid4() + + if configs is None: + configs = ChatOptions() + + self._passed_configs = configs + for field in self._passed_configs.model_fields: + setattr(self, field, self._passed_configs[field]) + + self.cache_dir.mkdir(parents=True, exist_ok=True) + + self.token_usage = defaultdict(lambda: {"input": 0, "output": 0}) + self.token_usage_db = TokenUsageDatabase(fpath=self.token_usage_db_path) + + if self.context_model == "full-history": + self.context_handler = FullHistoryChatContext(parent_chat=self) + elif self.context_model == "text-embedding-ada-002": + self.context_handler = EmbeddingBasedChatContext(parent_chat=self) + else: + raise NotImplementedError(f"Unknown context model: {self.context_model}") + + @property + def cache_dir(self): + """Return the cache directory for this chat.""" + return self._cache_dir + + @cache_dir.setter + def cache_dir(self, value): + if value is None: + value = GeneralConstants.CHAT_CACHE_DIR / f"chat_{self.id}" + self._cache_dir = Path(value) + + def save_cache(self): + """Store the chat's configs and metadata to the cache directory.""" + self.configs.export(self.configs_file) + + metadata = self.metadata # Trigger loading metadata if not yet done + with open(self.metadata_file, "w") as metadata_f: + json.dump(metadata, metadata_f, indent=2) + + def clear_cache(self): + """Remove the cache directory.""" + shutil.rmtree(self.cache_dir, ignore_errors=True) + + @property + def configs_file(self): + """File to store the chat's configs.""" + return self.cache_dir / "configs.json" + + @property + def context_file_path(self): + """Return the path to the file that stores the chat context and history.""" + return self.cache_dir / "embeddings.db" + + @property + def metadata_file(self): + """File to store the chat metadata.""" + return self.cache_dir / "metadata.json" + + @property + def metadata(self): + """Keep metadata associated with the chat.""" + try: + _ = self._metadata + except AttributeError: + try: + with open(self.metadata_file, "r") as f: + self._metadata = json.load(f) + except (FileNotFoundError, json.decoder.JSONDecodeError): + self._metadata = {} + return self._metadata + + @property + def configs(self): + """Return the chat's configs after initialisation.""" + configs_dict = {} + for field_name in ChatOptions.model_fields: + configs_dict[field_name] = getattr(self, field_name) + return ChatOptions.model_validate(configs_dict) + + @property + def base_directive(self): + """Return the base directive for the LLM.""" + msg_content = " ".join( + [ + instruction.strip() + for instruction in [ + f"You are {self.assistant_name} (model {self.model}).", + f"You are a 
helpful assistant to {self.username}.", + " ".join( + [f"{instruct.strip(' .')}." for instruct in self.ai_instructions] + ), + f"You must remember and follow all directives by {self.system_name}.", + ] + if instruction.strip() + ] + ) + return {"role": "system", "name": self.system_name, "content": msg_content} + + def __del__(self): + # Store token usage to database + for model in [self.model, self.context_model]: + self.token_usage_db.insert_data( + model=model, + n_input_tokens=self.token_usage[model]["input"], + n_output_tokens=self.token_usage[model]["output"], + ) + + cache_empty = self.cache_dir.exists() and not next( + self.cache_dir.iterdir(), False + ) + if self.private_mode or cache_empty: + self.clear_cache() + else: + self.save_cache() + + @classmethod + def from_dict(cls, configs: dict): + """Creates a Chat instance from a configuration dictionary. + + Converts the configuration dictionary into a ChatOptions instance + and uses it to instantiate the Chat class. + + Args: + configs (dict): The chat configuration options as a dictionary. + + Returns: + Chat: An instance of Chat initialized with the given configurations. + """ + return cls(configs=ChatOptions.model_validate(configs)) + + @classmethod + def from_cli_args(cls, cli_args): + """Creates a Chat instance from CLI arguments. + + Extracts relevant options from the CLI arguments and initializes a Chat instance + with them. + + Args: + cli_args: The command line arguments. + + Returns: + Chat: An instance of Chat initialized with CLI-specified configurations. + """ + chat_opts = { + k: v + for k, v in vars(cli_args).items() + if k in ChatOptions.model_fields and v is not None + } + return cls.from_dict(chat_opts) + + @classmethod + def from_cache(cls, cache_dir: Path): + """Loads a chat instance from a cache directory. + + Args: + cache_dir (Path): The path to the cache directory. + + Returns: + Chat: An instance of Chat loaded with cached configurations and metadata. + """ + try: + with open(cache_dir / "configs.json", "r") as configs_f: + new = cls.from_dict(json.load(configs_f)) + except FileNotFoundError: + new = cls() + return new + + def load_history(self): + """Load chat history from cache.""" + return self.context_handler.load_history() + + @property + def initial_greeting(self): + """Return the initial greeting for the chat.""" + return f"Hello! I'm {self.assistant_name}. How can I assist you today?" 
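+
+    # Illustrative usage sketch for the streaming methods defined below. This
+    # mirrors what `start()` does on the terminal and assumes a valid
+    # OPENAI_API_KEY is available to the OpenAI client:
+    #
+    #     chat = Chat()  # all-default ChatOptions
+    #     for chunk in chat.respond_user_prompt("Hello!"):
+    #         print(chunk, end="", flush=True)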
+
+    def respond_user_prompt(self, prompt: str, **kwargs):
+        """Respond to a user prompt."""
+        yield from self._respond_prompt(prompt=prompt, role="user", **kwargs)
+
+    def respond_system_prompt(self, prompt: str, **kwargs):
+        """Respond to a system prompt."""
+        yield from self._respond_prompt(prompt=prompt, role="system", **kwargs)
+
+    def yield_response_from_msg(self, prompt_msg: dict, add_to_history: bool = True):
+        """Yield response from a prompt message."""
+        # Get appropriate context for prompt from the context handler
+        prompt_context_request = self.context_handler.get_context(msg=prompt_msg)
+        context = prompt_context_request["context_messages"]
+
+        # Update token_usage with tokens used in context handler for prompt
+        self.token_usage[self.context_model]["input"] += sum(
+            prompt_context_request["tokens_usage"].values()
+        )
+
+        contextualised_prompt = [self.base_directive, *context, prompt_msg]
+        # Update token_usage with tokens used in chat input
+        self.token_usage[self.model]["input"] += get_n_tokens_from_msgs(
+            messages=contextualised_prompt, model=self.model
+        )
+
+        # Make API request and yield response chunks
+        full_reply_content = ""
+        for chunk in make_api_chat_completion_call(
+            conversation=contextualised_prompt, chat_obj=self
+        ):
+            full_reply_content += chunk
+            yield chunk
+
+        # Update token_usage with tokens used in chat output
+        reply_as_msg = {"role": "assistant", "content": full_reply_content}
+        self.token_usage[self.model]["output"] += get_n_tokens_from_msgs(
+            messages=[reply_as_msg], model=self.model
+        )
+
+        if add_to_history:
+            # Put current chat exchange in context handler's history
+            history_entry_reg_tokens_usage = self.context_handler.add_to_history(
+                msg_list=[
+                    prompt_msg,
+                    {"role": "assistant", "content": full_reply_content},
+                ]
+            )
+
+            # Update token_usage with tokens used in context handler for reply
+            self.token_usage[self.context_model]["output"] += sum(
+                history_entry_reg_tokens_usage.values()
+            )
+
+    def start(self):
+        """Start the chat."""
+        # ruff: noqa: T201
+        print(f"{self.assistant_name}> {self.initial_greeting}\n")
+        try:
+            while True:
+                question = input(f"{self.username}> ").strip()
+                if not question:
+                    continue
+                print(f"{self.assistant_name}> ", end="", flush=True)
+                for chunk in self.respond_user_prompt(prompt=question):
+                    print(chunk, end="", flush=True)
+                print()
+                print()
+        except (KeyboardInterrupt, EOFError):
+            print("", end="\r")
+            logger.info("Exiting chat.")
+
+    def report_token_usage(self, current_chat: bool = True):
+        """Report token usage and associated costs."""
+        self.token_usage_db.print_usage_costs(self.token_usage, current_chat=current_chat)
+
+    def _respond_prompt(self, prompt: str, role: str, **kwargs):
+        prompt_as_msg = {"role": role.lower().strip(), "content": prompt.strip()}
+        yield from self.yield_response_from_msg(prompt_as_msg, **kwargs)
+
+    @property
+    def api_connection_error_msg(self):
+        """Return the error message for API connection errors."""
+        return (
+            "Sorry, I'm having trouble communicating with OpenAI. "
+            + "Please check the validity of your API key and try again. "
+            + "If the problem persists, please also take a look at the "
+            + "OpenAI status page: https://status.openai.com."
+        )
diff --git a/pyrobbot/chat_configs.py b/pyrobbot/chat_configs.py
new file mode 100644
index 0000000..42d3607
--- /dev/null
+++ b/pyrobbot/chat_configs.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""Registration and validation of options."""
+import argparse
+import json
+import types
+import typing
+from getpass import getuser
+from pathlib import Path
+from typing import Literal, Optional, get_args, get_origin
+
+from pydantic import BaseModel, Field
+
+from pyrobbot import GeneralConstants
+
+
+class BaseConfigModel(BaseModel):
+    """Base model for configuring options."""
+
+    @classmethod
+    def get_allowed_values(cls, field: str):
+        """Return a tuple of allowed values for `field`."""
+        annotation = cls._get_field_param(field=field, param="annotation")
+        if isinstance(annotation, type(Literal[""])):
+            return get_args(annotation)
+        return None
+
+    @classmethod
+    def get_type(cls, field: str):
+        """Return the type of `field`."""
+        type_hint = typing.get_type_hints(cls)[field]
+        if isinstance(type_hint, type):
+            if isinstance(type_hint, types.GenericAlias):
+                return get_origin(type_hint)
+            return type_hint
+        type_hint_first_arg = get_args(type_hint)[0]
+        if isinstance(type_hint_first_arg, type):
+            return type_hint_first_arg
+        return None
+
+    @classmethod
+    def get_default(cls, field: str):
+        """Return the default value for `field`."""
+        return cls.model_fields[field].get_default()
+
+    @classmethod
+    def get_description(cls, field: str):
+        """Return the description of `field`."""
+        return cls._get_field_param(field=field, param="description")
+
+    @classmethod
+    def from_cli_args(cls, cli_args: argparse.Namespace):
+        """Return an instance of the class from CLI args."""
+        relevant_args = {
+            k: v
+            for k, v in vars(cli_args).items()
+            if k in cls.model_fields and v is not None
+        }
+        return cls.model_validate(relevant_args)
+
+    @classmethod
+    def _get_field_param(cls, field: str, param: str):
+        """Return the param `param` of the field `field`."""
+        return getattr(cls.model_fields[field], param, None)
+
+    def __getitem__(self, item):
+        """Make it possible to retrieve values as in a dict."""
+        try:
+            return getattr(self, item)
+        except AttributeError as error:
+            raise KeyError(item) from error
+
+    def export(self, fpath: Path):
+        """Export the model's data to a file."""
+        with open(fpath, "w") as configs_file:
+            configs_file.write(self.model_dump_json(indent=2, exclude_unset=True))
+
+    @classmethod
+    def from_file(cls, fpath: Path):
+        """Return an instance of the class given configs stored in a json file."""
+        with open(fpath, "r") as configs_file:
+            return cls.model_validate(json.load(configs_file))
+
+
+class OpenAiApiCallOptions(BaseConfigModel):
+    """Model for configuring options for OpenAI API calls."""
+
+    _openai_url = "https://platform.openai.com/docs/api-reference/chat/create#chat-create"
+    _models_url = "https://platform.openai.com/docs/models"
+
+    model: Literal[
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo-16k",  # Will point to gpt-3.5-turbo-1106 starting Dec 11, 2023
+        "gpt-3.5-turbo",  # Will point to gpt-3.5-turbo-1106 starting Dec 11, 2023
+        "gpt-4-1106-preview",
+        "gpt-4",
+    ] = Field(
+        default="gpt-3.5-turbo-1106",
+        description=f"OpenAI LLM model to use.
See {_openai_url}-model and {_models_url}", + ) + max_tokens: Optional[int] = Field( + default=None, gt=0, description=f"See <{_openai_url}-max_tokens>" + ) + presence_penalty: Optional[float] = Field( + default=None, ge=-2.0, le=2.0, description=f"See <{_openai_url}-presence_penalty>" + ) + frequency_penalty: Optional[float] = Field( + default=None, + ge=-2.0, + le=2.0, + description=f"See <{_openai_url}-frequency_penalty>", + ) + temperature: Optional[float] = Field( + default=None, ge=0.0, le=2.0, description=f"See <{_openai_url}-temperature>" + ) + top_p: Optional[float] = Field( + default=None, ge=0.0, le=1.0, description=f"See <{_openai_url}-top_p>" + ) + request_timeout: Optional[float] = Field( + default=10.0, gt=0.0, description="Timeout for API requests in seconds" + ) + + +class ChatOptions(OpenAiApiCallOptions): + """Model for the chat's configuration options.""" + + username: str = Field(default=getuser(), description="Name of the chat's user") + assistant_name: str = Field(default="Rob", description="Name of the chat's assistant") + system_name: str = Field( + default=f"{GeneralConstants.PACKAGE_NAME}_system", + description="Name of the chat's system", + ) + context_model: Literal["text-embedding-ada-002", "full-history"] = Field( + default="text-embedding-ada-002", + description=( + "Model to use for chat context (~memory). " + + "Once picked, it cannot be changed." + ), + json_schema_extra={"frozen": True}, + ) + cache_dir: Optional[Path] = Field( + default=None, + description="Directory where to store/save info about the chat.", + ) + ai_instructions: tuple[str, ...] = Field( + default=( + "You answer correctly.", + "You do not lie.", + "You answer with the fewest tokens possible.", + ), + description="Initial instructions for the AI", + ) + token_usage_db_path: Optional[Path] = Field( + default=GeneralConstants.TOKEN_USAGE_DATABASE, + description="Path to the token usage database", + ) + api_connection_max_n_attempts: int = Field( + default=5, + gt=0, + description="Maximum number of attempts to connect to the OpenAI API", + ) + private_mode: Optional[bool] = Field( + default=None, + description="Toggle private mode. 
If set to `True`, the chat will not " + + "be logged and the chat history will not be saved.", + ) diff --git a/pyrobbot/chat_context.py b/pyrobbot/chat_context.py new file mode 100644 index 0000000..c9141e8 --- /dev/null +++ b/pyrobbot/chat_context.py @@ -0,0 +1,167 @@ +"""Chat context/history management.""" +import ast +import itertools +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +import openai +import pandas as pd +from openai.embeddings_utils import cosine_similarity + +from .embeddings_database import EmbeddingsDatabase +from .openai_utils import retry_api_call + +if TYPE_CHECKING: + from .chat import Chat + + +class ChatContext(ABC): + """Abstract base class for representing the context of a chat.""" + + def __init__(self, parent_chat: "Chat"): + """Initialise the instance given a parent `Chat` object.""" + self.parent_chat = parent_chat + self.database = EmbeddingsDatabase( + db_path=self.context_file_path, embedding_model=self.embedding_model + ) + + @property + def embedding_model(self): + """Return the embedding model used for context management.""" + return self.parent_chat.context_model + + @property + def context_file_path(self): + """Return the path to the context file.""" + return self.parent_chat.context_file_path + + def add_to_history(self, msg_list: list[dict]): + """Add message exchange to history.""" + embedding_request = self.request_embedding(msg_list=msg_list) + self.database.insert_message_exchange( + chat_model=self.parent_chat.model, + message_exchange=msg_list, + embedding=embedding_request["embedding"], + ) + return embedding_request["tokens_usage"] + + def load_history(self) -> list[dict]: + """Load the chat history.""" + messages_df = self.database.get_messages_dataframe() + msg_exchanges = messages_df["message_exchange"].apply(ast.literal_eval).tolist() + return list(itertools.chain.from_iterable(msg_exchanges)) + + @abstractmethod + def request_embedding(self, msg_list: list[dict]): + """Request embedding from OpenAI API.""" + + @abstractmethod + def get_context(self, msg: dict): + """Return context messages.""" + + +class FullHistoryChatContext(ChatContext): + """Context class using full chat history.""" + + def __init__(self, *args, **kwargs): + """Initialise instance. 
diff --git a/pyrobbot/chat_context.py b/pyrobbot/chat_context.py new file mode 100644 index 0000000..c9141e8 --- /dev/null +++ b/pyrobbot/chat_context.py @@ -0,0 +1,167 @@ +"""Chat context/history management.""" +import ast +import itertools +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +import openai +import pandas as pd +from openai.embeddings_utils import cosine_similarity + +from .embeddings_database import EmbeddingsDatabase +from .openai_utils import retry_api_call + +if TYPE_CHECKING: + from .chat import Chat + + +class ChatContext(ABC): + """Abstract base class for representing the context of a chat.""" + + def __init__(self, parent_chat: "Chat"): + """Initialise the instance given a parent `Chat` object.""" + self.parent_chat = parent_chat + self.database = EmbeddingsDatabase( + db_path=self.context_file_path, embedding_model=self.embedding_model + ) + + @property + def embedding_model(self): + """Return the embedding model used for context management.""" + return self.parent_chat.context_model + + @property + def context_file_path(self): + """Return the path to the context file.""" + return self.parent_chat.context_file_path + + def add_to_history(self, msg_list: list[dict]): + """Add message exchange to history.""" + embedding_request = self.request_embedding(msg_list=msg_list) + self.database.insert_message_exchange( + chat_model=self.parent_chat.model, + message_exchange=msg_list, + embedding=embedding_request["embedding"], + ) + return embedding_request["tokens_usage"] + + def load_history(self) -> list[dict]: + """Load the chat history.""" + messages_df = self.database.get_messages_dataframe() + msg_exchanges = messages_df["message_exchange"].apply(ast.literal_eval).tolist() + return list(itertools.chain.from_iterable(msg_exchanges)) + + @abstractmethod + def request_embedding(self, msg_list: list[dict]): + """Request embedding from OpenAI API.""" + + @abstractmethod + def get_context(self, msg: dict): + """Return context messages.""" + + +class FullHistoryChatContext(ChatContext): + """Context class using full chat history.""" + + def __init__(self, *args, **kwargs): + """Initialise instance. Args and kwargs are passed to the parent class's `__init__`.""" + super().__init__(*args, **kwargs) + self._placeholder_tokens_usage = {"input": 0, "output": 0} + + # Implement abstract methods + def request_embedding(self, msg_list: list[dict]): # noqa: ARG002 + """Return a placeholder embedding request.""" + return {"embedding": None, "tokens_usage": self._placeholder_tokens_usage} + + def get_context(self, msg: dict): # noqa: ARG002 + """Return context messages.""" + context_msgs = _make_list_of_context_msgs( + history=self.load_history(), system_name=self.parent_chat.system_name + ) + return { + "context_messages": context_msgs, + "tokens_usage": self._placeholder_tokens_usage, + } + + +class EmbeddingBasedChatContext(ChatContext): + """Chat context using embedding models.""" + + def _request_embedding_for_text(self, text: str): + return request_embedding_from_openai(text=text, model=self.embedding_model) + + # Implement abstract methods + def request_embedding(self, msg_list: list[dict]): + """Request embedding from OpenAI API.""" + text = "\n".join( + [f"{msg['role'].strip()}: {msg['content'].strip()}" for msg in msg_list] + ) + return self._request_embedding_for_text(text=text) + + def get_context(self, msg: dict): + """Return context messages.""" + embedding_request = self._request_embedding_for_text(text=msg["content"]) + selected_history = _select_relevant_history( + history_df=self.database.get_messages_dataframe(), + embedding=embedding_request["embedding"], + ) + context_messages = _make_list_of_context_msgs( + history=selected_history, system_name=self.parent_chat.system_name + ) + return { + "context_messages": context_messages, + "tokens_usage": embedding_request["tokens_usage"], + } + + +@retry_api_call() +def request_embedding_from_openai(text: str, model: str): + """Request embedding for `text` according to context model `model` from OpenAI.""" + text = text.strip() + embedding_request = openai.Embedding.create(input=[text], model=model) + + embedding = embedding_request["data"][0]["embedding"] + + input_tokens = embedding_request["usage"]["prompt_tokens"] + output_tokens = embedding_request["usage"]["total_tokens"] - input_tokens + tokens_usage = {"input": input_tokens, "output": output_tokens} + + return {"embedding": embedding, "tokens_usage": tokens_usage} + + +def _make_list_of_context_msgs(history: list[dict], system_name: str): + sys_directives = "Considering the previous messages, answer the next message:" + sys_msg = {"role": "system", "name": system_name, "content": sys_directives} + return [*history, sys_msg] + + +def _select_relevant_history( + history_df: pd.DataFrame, + embedding: list[float], + max_n_prompt_reply_pairs: int = 5, + max_n_tailing_prompt_reply_pairs: int = 2, +): + history_df["embedding"] = ( + history_df["embedding"].apply(ast.literal_eval).apply(np.array) + ) + history_df["similarity"] = history_df["embedding"].apply( + lambda x: cosine_similarity(x, embedding) + ) + + # Get the last messages added to the history + df_last_n_chats = history_df.tail(max_n_tailing_prompt_reply_pairs) + + # Get the most similar messages + df_similar_chats = ( + history_df.sort_values("similarity", ascending=False) + .head(max_n_prompt_reply_pairs) + .sort_values("timestamp") + ) + + df_context = pd.concat([df_similar_chats, df_last_n_chats]) + selected_history = ( + df_context["message_exchange"].apply(ast.literal_eval).drop_duplicates() + ).tolist() + + return list(itertools.chain.from_iterable(selected_history))
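The embedding-based context above decides which past exchanges to resend with each prompt: the ones most similar to the new message by cosine similarity, plus the most recent ones. Below is a minimal, self-contained sketch of just the ranking step. It is not part of the patch: the embeddings are made up (and kept as arrays rather than the JSON strings stored in SQLite), and the cosine similarity is computed inline instead of via `openai.embeddings_utils`.

    # Illustrative only: the ranking performed by _select_relevant_history,
    # with made-up 4-dimensional embeddings.
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(seed=0)
    history_df = pd.DataFrame(
        {
            "timestamp": [1, 2, 3],
            "message_exchange": ["exchange 0", "exchange 1", "exchange 2"],
            "embedding": [rng.random(4) for _ in range(3)],
        }
    )
    query_embedding = rng.random(4)  # stand-in for the new prompt's embedding

    def cosine_similarity(u, v):
        return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

    history_df["similarity"] = history_df["embedding"].apply(
        lambda emb: cosine_similarity(emb, query_embedding)
    )
    # Most similar exchanges first; the real code also appends the newest ones
    print(history_df.sort_values("similarity", ascending=False).head(5))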
diff --git a/pyrobbot/command_definitions.py b/pyrobbot/command_definitions.py new file mode 100644 index 0000000..62ddf2f --- /dev/null +++ b/pyrobbot/command_definitions.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +"""Commands supported by the package's script.""" +import subprocess + +from loguru import logger + +from . import GeneralConstants +from .chat import Chat +from .chat_configs import ChatOptions + + +def accounting(args): + """Show the accumulated costs of the chat and exit.""" + chat = Chat.from_cli_args(cli_args=args) + chat.private_mode = True + chat.report_token_usage(current_chat=False) + + +def run_on_terminal(args): + """Run the chat on the terminal.""" + chat = Chat.from_cli_args(cli_args=args) + chat.start() + if args.report_accounting_when_done: + chat.report_token_usage(current_chat=True) + + +def run_on_ui(args): + """Run the chat on the browser.""" + ChatOptions.from_cli_args(args).export(fpath=GeneralConstants.PARSED_ARGS_FILE) + try: + subprocess.run( + [ # noqa: S603, S607 + "streamlit", + "run", + GeneralConstants.APP_PATH.as_posix(), + "--", + GeneralConstants.PARSED_ARGS_FILE.as_posix(), + ], + cwd=GeneralConstants.APP_DIR.as_posix(), + check=True, + ) + except (KeyboardInterrupt, EOFError): + logger.info("Exiting.") diff --git a/pyrobbot/embeddings_database.py b/pyrobbot/embeddings_database.py new file mode 100644 index 0000000..45f3ec3 --- /dev/null +++ b/pyrobbot/embeddings_database.py @@ -0,0 +1,148 @@ +"""Management of embeddings/chat history storage and retrieval.""" +import datetime +import json +import sqlite3 +from pathlib import Path + +import pandas as pd + + +class EmbeddingsDatabase: + """Class for managing an SQLite database storing embeddings and associated data.""" + + def __init__(self, db_path: Path, embedding_model: str): + """Initialise the EmbeddingsDatabase object. + + Args: + db_path (Path): The path to the SQLite database file. + embedding_model (str): The embedding model associated with this database. + """ + self.db_path = db_path + self.embedding_model = embedding_model + self.create() + + def create(self): + """Create the necessary tables and triggers in the SQLite database.""" + self.db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self.db_path) + + # SQL to create 'embedding_model' table with 'embedding_model' as primary key + create_embedding_model_table = """ + CREATE TABLE IF NOT EXISTS embedding_model ( + created_timestamp INTEGER NOT NULL, + embedding_model TEXT NOT NULL, + PRIMARY KEY (embedding_model) + ) + """ + + # SQL to create 'messages' table + create_messages_table = """ + CREATE TABLE IF NOT EXISTS messages ( + timestamp INTEGER NOT NULL, + chat_model TEXT NOT NULL, + message_exchange TEXT NOT NULL, + embedding TEXT + ) + """ + + with conn: + # Create tables + conn.execute(create_embedding_model_table) + conn.execute(create_messages_table) + + # Triggers to prevent modification after insertion + conn.execute( + """ + CREATE TRIGGER IF NOT EXISTS prevent_embedding_model_modification + BEFORE UPDATE ON embedding_model + BEGIN + SELECT RAISE(FAIL, 'modification not allowed'); + END; + """ + ) + + conn.execute( + """ + CREATE TRIGGER IF NOT EXISTS prevent_messages_modification + BEFORE UPDATE ON messages + BEGIN + SELECT RAISE(FAIL, 'modification not allowed'); + END; + """ + ) + + # Close the connection to the database + conn.close() + + def get_embedding_model(self): + """Retrieve the database's embedding model. + + Returns: + str: The embedding model, or None if the database is not yet initialised. 
+ """ + conn = sqlite3.connect(self.db_path) + query = "SELECT embedding_model FROM embedding_model;" + # Execute the query and fetch the result + embedding_model = None + with conn: + cur = conn.cursor() + cur.execute(query) + result = cur.fetchone() + embedding_model = result[0] if result else None + + conn.close() + + return embedding_model + + def insert_message_exchange(self, chat_model, message_exchange, embedding): + """Insert a message exchange into the database's 'messages' table. + + Args: + chat_model (str): The chat model. + message_exchange: The message exchange. + embedding: The embedding associated with the message exchange. + + Raises: + ValueError: If the database already contains a different embedding model. + """ + stored_embedding_model = self.get_embedding_model() + if stored_embedding_model is None: + self._init_database() + elif stored_embedding_model != self.embedding_model: + raise ValueError( + "Database already contains a different embedding model: " + f"{self.get_embedding_model()}.\n" + "Cannot continue." + ) + + timestamp = int(datetime.datetime.utcnow().timestamp()) + message_exchange = json.dumps(message_exchange) + embedding = json.dumps(embedding) + conn = sqlite3.connect(self.db_path) + sql = "INSERT INTO messages " + sql += "(timestamp, chat_model, message_exchange, embedding) VALUES (?, ?, ?, ?);" + with conn: + conn.execute(sql, (timestamp, chat_model, message_exchange, embedding)) + conn.close() + + def get_messages_dataframe(self): + """Retrieve msg exchanges from the `messages` table. Return them as a DataFrame. + + Returns: + pd.DataFrame: A DataFrame containing the message exchanges. + """ + conn = sqlite3.connect(self.db_path) + query = "SELECT * FROM messages;" + messages_df = pd.read_sql_query(query, conn) + conn.close() + return messages_df + + def _init_database(self): + """Initialise the 'embedding_model' table in the database.""" + conn = sqlite3.connect(self.db_path) + create_time = int(datetime.datetime.utcnow().timestamp()) + sql = "INSERT INTO embedding_model " + sql += "(created_timestamp, embedding_model) VALUES (?, ?);" + with conn: + conn.execute(sql, (create_time, self.embedding_model)) + conn.close() diff --git a/pyrobbot/openai_utils.py b/pyrobbot/openai_utils.py new file mode 100644 index 0000000..34aee2b --- /dev/null +++ b/pyrobbot/openai_utils.py @@ -0,0 +1,95 @@ +"""Utils for using the OpenAI API.""" +import inspect +import time +from functools import wraps +from typing import TYPE_CHECKING + +import openai +from loguru import logger + +from .chat_configs import OpenAiApiCallOptions + +if TYPE_CHECKING: + from .chat import Chat + + +class CannotConnectToApiError(Exception): + """Error raised when the package cannot connect to the OpenAI API.""" + + +def retry_api_call(max_n_attempts=5, auth_error_msg="Problems connecting to OpenAI API."): + """Retry connecting to the API up to a maximum number of times.""" + handled_exceptions = ( + openai.error.ServiceUnavailableError, + openai.error.Timeout, + openai.error.APIError, + ) + + def on_error(error, n_attempts): + if n_attempts < max_n_attempts: + logger.warning( + "{}. 
Making new attempt ({}/{})...", error, n_attempts + 1, max_n_attempts + ) + time.sleep(1) + else: + raise CannotConnectToApiError(auth_error_msg) from error + + def retry_api_call_decorator(function): + """Wrap `function`, retrying the API call on the handled exceptions.""" + + @wraps(function) + def wrapper_f(*args, **kwargs): + n_attempts = 0 + while True: + n_attempts += 1 + try: + return function(*args, **kwargs) + except handled_exceptions as error: + on_error(error=error, n_attempts=n_attempts) + except openai.error.AuthenticationError as error: + raise CannotConnectToApiError(auth_error_msg) from error + + @wraps(function) + def wrapper_generator_f(*args, **kwargs): + n_attempts = 0 + success = False + while not success: + n_attempts += 1 + try: + yield from function(*args, **kwargs) + except handled_exceptions as error: + on_error(error=error, n_attempts=n_attempts) + except openai.error.AuthenticationError as error: + raise CannotConnectToApiError(auth_error_msg) from error + else: + success = True + + return wrapper_generator_f if inspect.isgeneratorfunction(function) else wrapper_f + + return retry_api_call_decorator + + +def make_api_chat_completion_call(conversation: list, chat_obj: "Chat"): + """Stream a chat completion from OpenAI API given a conversation and a chat object. + + Args: + conversation (list): A list of messages passed as input for the completion. + chat_obj (Chat): Chat object containing the configurations for the chat. + + Yields: + str: Chunks of text generated by the API in response to the conversation. + """ + api_call_args = {} + for field in OpenAiApiCallOptions.model_fields: + if getattr(chat_obj, field) is not None: + api_call_args[field] = getattr(chat_obj, field) + + @retry_api_call(auth_error_msg=chat_obj.api_connection_error_msg) + def stream_reply(conversation, **api_call_args): + for completion_chunk in openai.ChatCompletion.create( + messages=conversation, stream=True, **api_call_args + ): + reply_chunk = getattr(completion_chunk.choices[0].delta, "content", "") + yield reply_chunk + + yield from stream_reply(conversation, **api_call_args)
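`retry_api_call` is a decorator factory that handles both regular and generator functions: it retries on `ServiceUnavailableError`, `Timeout` and `APIError`, sleeps 1 second between attempts, and raises `CannotConnectToApiError` once the attempts run out (or immediately on an authentication error). A minimal sketch of the intended usage; `flaky_api_call` is a hypothetical stand-in for a real API request, not part of the patch.

    # Illustrative only: a call that always times out, to show the retry path.
    import openai

    from pyrobbot.openai_utils import CannotConnectToApiError, retry_api_call

    @retry_api_call(max_n_attempts=3, auth_error_msg="Could not reach the API.")
    def flaky_api_call():
        # Raising one of the handled exceptions triggers a retry
        raise openai.error.Timeout("simulated timeout")

    try:
        flaky_api_call()
    except CannotConnectToApiError as error:
        print(f"Gave up after 3 attempts: {error}")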
diff --git a/pyrobbot/tokens.py b/pyrobbot/tokens.py new file mode 100644 index 0000000..e52a332 --- /dev/null +++ b/pyrobbot/tokens.py @@ -0,0 +1,264 @@ +"""Management of token usage and costs for OpenAI API.""" +import datetime +import sqlite3 +from pathlib import Path + +import pandas as pd +import tiktoken + +# See <https://openai.com/pricing> for the latest prices. +PRICE_PER_K_TOKENS = { + "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002}, + "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004}, + "gpt-3.5-turbo-1106": {"input": 0.001, "output": 0.002}, + "gpt-4-1106-preview": {"input": 0.01, "output": 0.03}, + "gpt-4": {"input": 0.03, "output": 0.06}, + "text-embedding-ada-002": {"input": 0.0001, "output": 0.0}, + "full-history": {"input": 0.0, "output": 0.0}, +} + + +class TokenUsageDatabase: + """Manages a database to store estimated token usage and costs for OpenAI API.""" + + def __init__(self, fpath: Path): + """Initialise a TokenUsageDatabase instance.""" + self.fpath = fpath + self.token_price = {} + for model, price_per_k_tokens in PRICE_PER_K_TOKENS.items(): + self.token_price[model] = { + k: v / 1000.0 for k, v in price_per_k_tokens.items() + } + + self.create() + + def create(self): + """Create the database if it doesn't exist.""" + self.fpath.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self.fpath) + cursor = conn.cursor() + + # Create a table to store the data with 'timestamp' as the primary key + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS token_costs ( + timestamp REAL PRIMARY KEY, + model TEXT, + n_input_tokens INTEGER, + n_output_tokens INTEGER, + cost_input_tokens REAL, + cost_output_tokens REAL + ) + """ + ) + + conn.commit() + conn.close() + + def insert_data(self, model, n_input_tokens, n_output_tokens): + """Insert the data into the token_costs table.""" + if model is None: + return + + conn = sqlite3.connect(self.fpath) + cursor = conn.cursor() + + # Insert the data into the table + cursor.execute( + """ + INSERT OR REPLACE INTO token_costs ( + timestamp, + model, + n_input_tokens, + n_output_tokens, + cost_input_tokens, + cost_output_tokens + ) + VALUES (?, ?, ?, ?, ?, ?) 
+ """, + ( + datetime.datetime.utcnow().timestamp(), + model, + n_input_tokens, + n_output_tokens, + n_input_tokens * self.token_price[model]["input"], + n_output_tokens * self.token_price[model]["output"], + ), + ) + + conn.commit() + conn.close() + + def retrieve_sums_by_model(self): + """Retrieve the sums of tokens and costs by each model.""" + conn = sqlite3.connect(self.fpath) + cursor = conn.cursor() + + cursor.execute( + """ + SELECT + model, + MIN(timestamp) AS earliest_timestamp, + SUM(n_input_tokens) AS total_n_input_tokens, + SUM(n_output_tokens) AS total_n_output_tokens, + SUM(cost_input_tokens) AS total_cost_input_tokens, + SUM(cost_output_tokens) AS total_cost_output_tokens + FROM token_costs + GROUP BY model + """ + ) + + data = cursor.fetchall() + + conn.close() + + sums_by_model = {} + for row in data: + model_name = row[0] + sums = { + "earliest_timestamp": row[1], + "n_input_tokens": row[2], + "n_output_tokens": row[3], + "cost_input_tokens": row[4], + "cost_output_tokens": row[5], + } + sums_by_model[model_name] = sums + + return sums_by_model + + def get_usage_balance_dataframe(self): + """Get a dataframe with the accumulated token usage and costs.""" + sums_by_model = self.retrieve_sums_by_model() + df_rows = [] + for model, accumulated_usage in sums_by_model.items(): + if model is None: + continue + + accumulated_tokens_usage = { + "input": accumulated_usage["n_input_tokens"], + "output": accumulated_usage["n_output_tokens"], + } + accumlated_costs = { + "input": accumulated_usage["cost_input_tokens"], + "output": accumulated_usage["cost_output_tokens"], + } + first_used = datetime.datetime.fromtimestamp( + accumulated_usage["earliest_timestamp"], datetime.timezone.utc + ).isoformat(sep=" ", timespec="seconds") + df_row = { + "Model": model, + "First Registered Use": first_used.replace("+00:00", "Z"), + "Tokens: Input": accumulated_tokens_usage["input"], + "Tokens: Output": accumulated_tokens_usage["output"], + "Tokens: Total": sum(accumulated_tokens_usage.values()), + "Cost ($): Input": accumlated_costs["input"], + "Cost ($): Output": accumlated_costs["output"], + "Cost ($): Total": sum(accumlated_costs.values()), + } + df_rows.append(df_row) + + usage_df = pd.DataFrame(df_rows) + if not usage_df.empty: + usage_df = _add_totals_row(_group_columns_by_prefix(usage_df)) + + return usage_df + + def get_current_chat_usage_dataframe(self, token_usage_per_model: dict): + """Get a dataframe with the current chat's token usage and costs.""" + df_rows = [] + for model, token_usage in token_usage_per_model.items(): + if model is None: + continue + + costs = {k: v * self.token_price[model][k] for k, v in token_usage.items()} + df_row = { + "Model": model, + "Tokens: Input": token_usage["input"], + "Tokens: Output": token_usage["output"], + "Tokens: Total": sum(token_usage.values()), + "Cost ($): Input": costs["input"], + "Cost ($): Output": costs["output"], + "Cost ($): Total": sum(costs.values()), + } + df_rows.append(df_row) + chat_usage_df = pd.DataFrame(df_rows) + if df_rows: + chat_usage_df = _group_columns_by_prefix(chat_usage_df.set_index("Model")) + chat_usage_df = _add_totals_row(chat_usage_df) + return chat_usage_df + + def print_usage_costs(self, token_usage: dict, current_chat: bool = True): + """Print the estimated token usage and costs.""" + header_start = "Estimated token usage and associated costs" + header2dataframe = { + f"{header_start}: Accumulated": self.get_usage_balance_dataframe(), + f"{header_start}: Current Chat": self.get_current_chat_usage_dataframe( 
+ token_usage + ), + } + + for header, df in header2dataframe.items(): + if "current" in header.lower() and not current_chat: + continue + _print_df(df=df, header=header) + + print() + print("Note: These are only estimates. Actual costs may vary.") + link = "https://platform.openai.com/account/usage" + print(f"Please visit <{link}> to follow your actual usage and costs.") + + +def get_n_tokens_from_msgs(messages: list[dict], model: str): + """Return the number of tokens used by a list of messages.""" + # Adapted from OpenAI's token-counting example in the openai-cookbook repo: + # <https://github.com/openai/openai-cookbook> + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + # OpenAI's original function was implemented for gpt-3.5-turbo-0613, but we'll use + # it for all models for now. We are only interested in estimates, after all. + num_tokens = 0 + for message in messages: + # every message follows {role/name}\n{content}\n + num_tokens += 4 + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": # if there's a name, the role is omitted + num_tokens += -1 # role is always required and always 1 token + num_tokens += 2 # every reply is primed with assistant + return num_tokens + + +def _group_columns_by_prefix(dataframe: pd.DataFrame): + dataframe = dataframe.copy() + col_tuples_for_multiindex = dataframe.columns.str.split(": ", expand=True).to_numpy() + dataframe.columns = pd.MultiIndex.from_tuples( + [("", x[0]) if pd.isna(x[1]) else x for x in col_tuples_for_multiindex] + ) + return dataframe + + +def _add_totals_row(accounting_df: pd.DataFrame): + accounting_df = accounting_df.copy() + dtypes = accounting_df.dtypes + accounting_df.loc["Total"] = accounting_df.sum(numeric_only=True) + for col in accounting_df.columns: + accounting_df[col] = accounting_df[col].astype(dtypes[col]) + accounting_df = accounting_df.fillna("") + return accounting_df + + +def _print_df(df: pd.DataFrame, header: str): + # ruff: noqa: T201 + underline = "-" * len(header) + print() + print(underline) + print(header) + print(underline) + if df.empty or df.loc["Total"]["Tokens"]["Total"] == 0: + print("None.") + else: + print(df) + print()
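Putting the two pieces together, estimating a conversation's input cost amounts to counting tokens with `get_n_tokens_from_msgs` and scaling by the per-1000-token price. A self-contained sketch, not part of the patch, with a made-up two-message conversation (requires `tiktoken`):

    # Illustrative only: back-of-the-envelope input-cost estimate.
    from pyrobbot.tokens import PRICE_PER_K_TOKENS, get_n_tokens_from_msgs

    messages = [  # made-up conversation
        {"role": "system", "content": "You answer correctly."},
        {"role": "user", "content": "Hi!"},
    ]
    model = "gpt-3.5-turbo"
    n_input_tokens = get_n_tokens_from_msgs(messages=messages, model=model)
    # Prices in PRICE_PER_K_TOKENS are USD per 1000 tokens
    cost = n_input_tokens * PRICE_PER_K_TOKENS[model]["input"] / 1000.0
    print(f"{n_input_tokens} input tokens, estimated cost ${cost:.6f}")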
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..cbb018a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,122 @@ +import os + +import lorem +import numpy as np +import openai +import pytest + +import pyrobbot +from pyrobbot.chat import Chat +from pyrobbot.chat_configs import ChatOptions + + +# Register markers and constants +def pytest_configure(config): + config.addinivalue_line( + "markers", + "no_chat_completion_create_mocking: do not mock openai.ChatCompletion.create", + ) + config.addinivalue_line( + "markers", + "no_embedding_create_mocking: mark test to not mock openai.Embedding.create", + ) + + pytest.ORIGINAL_PACKAGE_CACHE_DIRECTORY = ( + pyrobbot.GeneralConstants.PACKAGE_CACHE_DIRECTORY + ) + + +@pytest.fixture(autouse=True) +def _set_env(): + # Make sure we don't consume our tokens in tests + os.environ["OPENAI_API_KEY"] = "INVALID_API_KEY" + openai.api_key = os.environ["OPENAI_API_KEY"] + + +@pytest.fixture(autouse=True) +def _mocked_general_constants(tmp_path): + pyrobbot.GeneralConstants.PACKAGE_CACHE_DIRECTORY = tmp_path / "cache" + + +@pytest.fixture(autouse=True) +def _openai_api_request_mockers(request, mocker): + """Mockers for OpenAI API requests. We don't want to consume our tokens in tests.""" + + def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001 + """Mock `openai.ChatCompletion.create`. Yield from lorem ipsum instead.""" + completion_chunk = type("CompletionChunk", (), {}) + completion_chunk_choice = type("CompletionChunkChoice", (), {}) + completion_chunk_choice_delta = type("CompletionChunkChoiceDelta", (), {}) + for word in lorem.get_paragraph().split(): + completion_chunk_choice_delta.content = word + " " + completion_chunk_choice.delta = completion_chunk_choice_delta + completion_chunk.choices = [completion_chunk_choice] + yield completion_chunk + + def _mock_openai_embedding_create(*args, **kwargs): # noqa: ARG001 + """Mock `openai.Embedding.create`. Return a random embedding instead.""" + embedding_request = { + "data": [{"embedding": np.random.rand(512).tolist()}], + "usage": {"prompt_tokens": 0, "total_tokens": 0}, + } + return embedding_request + + if "no_chat_completion_create_mocking" not in request.keywords: + mocker.patch( + "openai.ChatCompletion.create", new=_mock_openai_chat_completion_create + ) + if "no_embedding_create_mocking" not in request.keywords: + mocker.patch("openai.Embedding.create", new=_mock_openai_embedding_create) + + +@pytest.fixture() +def _input_builtin_mocker(mocker, user_input): + """Mock the `input` builtin. Raise `KeyboardInterrupt` after the second call.""" + + # We allow two calls in order to allow for the chat context handler to kick in + def _mock_input(*args, **kwargs): # noqa: ARG001 + try: + _mock_input.execution_counter += 1 + except AttributeError: + _mock_input.execution_counter = 0 + if _mock_input.execution_counter > 1: + raise KeyboardInterrupt + return user_input + + mocker.patch( # noqa: PT008 + "builtins.input", new=lambda _: _mock_input(user_input=user_input) + ) + + +@pytest.fixture(params=ChatOptions.get_allowed_values("model")) +def llm_model(request): + return request.param + + +@pytest.fixture(params=ChatOptions.get_allowed_values("context_model")) +def context_model(request): + return request.param + + +@pytest.fixture() +def default_chat_configs(llm_model, context_model, tmp_path): + return ChatOptions( + model=llm_model, + context_model=context_model, + token_usage_db_path=tmp_path / "token_usage.db", # Don't use the regular db file + cache_dir=tmp_path, # Don't use our cache files + ) + + +@pytest.fixture() +def cli_args_overrides(default_chat_configs): + args = [] + for field, value in default_chat_configs.model_dump().items(): + if value is not None: + args = [*args, *[f"--{field.replace('_', '-')}", str(value)]] + return args + + +@pytest.fixture() +def default_chat(default_chat_configs): + return Chat(configs=default_chat_configs) diff --git a/tests/smoke/test_app.py b/tests/smoke/test_app.py new file mode 100644 index 0000000..046c221 --- /dev/null +++ b/tests/smoke/test_app.py @@ -0,0 +1,10 @@ +from pyrobbot.app import app + + +def test_app(mocker, default_chat_configs): + mocker.patch("streamlit.session_state", {}) + mocker.patch( + "pyrobbot.chat_configs.ChatOptions.from_file", + return_value=default_chat_configs, + ) + app.run_app() diff --git a/tests/smoke/test_commands.py b/tests/smoke/test_commands.py new file mode 100644 index 0000000..9cecf9b --- /dev/null +++ b/tests/smoke/test_commands.py @@ -0,0 +1,27 @@ +import pytest + +from pyrobbot.__main__ import main +from pyrobbot.argparse_wrapper import get_parsed_args + + +@pytest.mark.usefixtures("_input_builtin_mocker") +@pytest.mark.parametrize("user_input", ["Hi!", ""], 
ids=["regular-input", "empty-input"]) +def test_terminal_command(cli_args_overrides): + args = ["terminal", "--report-accounting-when-done", *cli_args_overrides] + args = list(dict.fromkeys(args)) + main(args) + + +def test_accounting_command(): + main(["accounting"]) + + +def test_default_command(mocker): + def _mock_subprocess_run(*args, **kwargs): # noqa: ARG001 + raise KeyboardInterrupt("Mocked KeyboardInterrupt") + + args = get_parsed_args(argv=[]) + assert args.command == "ui" + + mocker.patch("subprocess.run", new=_mock_subprocess_run) + main(argv=[]) diff --git a/tests/unit/test_chat.py b/tests/unit/test_chat.py new file mode 100644 index 0000000..ff72d49 --- /dev/null +++ b/tests/unit/test_chat.py @@ -0,0 +1,62 @@ +import openai +import pytest + +from pyrobbot import GeneralConstants +from pyrobbot.openai_utils import CannotConnectToApiError + + +@pytest.mark.order(1) +@pytest.mark.usefixtures("_input_builtin_mocker") +@pytest.mark.no_chat_completion_create_mocking() +@pytest.mark.parametrize("user_input", ["regular-input"]) +def test_testbed_doesnt_actually_connect_to_openai(default_chat): + with pytest.raises( # noqa: PT012 + CannotConnectToApiError, match=default_chat.api_connection_error_msg + ): + try: + default_chat.start() + except CannotConnectToApiError: + raise + else: + pytest.exit("Refuse to continue: Testbed is trying to connect to OpenAI API!") + + +@pytest.mark.order(2) +def test_we_are_using_tmp_cachedir(): + try: + assert ( + GeneralConstants.PACKAGE_CACHE_DIRECTORY + != pytest.ORIGINAL_PACKAGE_CACHE_DIRECTORY + ) + + except AssertionError: + pytest.exit( + "Refuse to continue: Tests attempted to use the package's real cache dir " + + f"({GeneralConstants.PACKAGE_CACHE_DIRECTORY})!" + ) + + +@pytest.mark.usefixtures("_input_builtin_mocker") +@pytest.mark.parametrize("user_input", ["Hi!", ""], ids=["regular-input", "empty-input"]) +def test_terminal_chat(default_chat): + default_chat.start() + default_chat.__del__() # Just to trigger testing the custom del method + + +def test_chat_configs(default_chat, default_chat_configs): + assert default_chat._passed_configs == default_chat_configs + + +@pytest.mark.no_chat_completion_create_mocking() +@pytest.mark.usefixtures("_input_builtin_mocker") +@pytest.mark.parametrize("user_input", ["regular-input"]) +def test_request_timeout_retry(mocker, default_chat): + def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001 + raise openai.error.Timeout("Mocked timeout error was not caught!") + + mocker.patch("openai.ChatCompletion.create", new=_mock_openai_chat_completion_create) + mocker.patch("time.sleep") # Don't waste time sleeping in tests + with pytest.raises( + CannotConnectToApiError, match=default_chat.api_connection_error_msg + ): + default_chat.start()