diff --git a/README.md b/README.md index 089c62c..ee2cdd4 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ Commands such as saving and loading conversations are available as the following | `/q` | Exit from this chat tool | | `/clear` | Clear chat history all | | `/history` | Show chat history in json format | -| `/save` | Save chat hisotry in json format | -| `/load` | Load chat hisotry from a json file | +| `/save` | Save chat history in json format | +| `/load` | Load chat history from a json file | | `/help` | Show all commands which you can use in this chat tool | ## Configuration diff --git a/oregpt/chat_bot.py b/oregpt/chat_bot.py index cc05f01..9ce56c0 100644 --- a/oregpt/chat_bot.py +++ b/oregpt/chat_bot.py @@ -7,23 +7,25 @@ class ChatBot: - SYSTEM_ROLE = [{"role": "system", "content": "You are a chat bot"}] - - def __init__(self, model: str, std_in_out: StdInOut): + def __init__(self, model: str, assistant_role: str, std_in_out: StdInOut): self._std_in_out = std_in_out # Model list: gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301 # https://platform.openai.com/docs/models/overview self._model = model - + self._assistant_role = [{"role": "system", "content": assistant_role}] self._initialize_log() def _initialize_log(self) -> None: - self._log: list[dict[str, str]] = deepcopy(ChatBot.SYSTEM_ROLE) + self._log: list[dict[str, str]] = deepcopy(self._assistant_role) @property def model(self) -> str: return self._model + @property + def assistant_role(self) -> list[dict[str, str]]: + return deepcopy(self._assistant_role) + @property def log(self) -> list[dict[str, str]]: return self._log diff --git a/oregpt/command.py b/oregpt/command.py index 40d4d1e..a580cf2 100644 --- a/oregpt/command.py +++ b/oregpt/command.py @@ -88,7 +88,7 @@ def _abspath(x: str) -> str: @register class SaveCommand(Command): - """Save chat hisotry in json format""" + """Save chat history in json format""" representations: list[str] = ["save"] @@ -107,7 +107,7 @@ def execute(self) -> None: @register class LoadCommand(Command): - """Load chat hisotry from a json file""" + """Load chat history from a json file""" representations: list[str] = ["load"] diff --git a/oregpt/main.py b/oregpt/main.py index 3b7c6ef..7a94b35 100644 --- a/oregpt/main.py +++ b/oregpt/main.py @@ -1,8 +1,10 @@ import os import pathlib import shutil -from typing import Any +from enum import Enum +from typing import Any, Optional +import click import openai import yaml @@ -34,28 +36,53 @@ def initialize_open_ai_key(config: dict[str, Any]) -> None: raise LookupError("OpenAI's API key was not found in config.yml and environment variables") -def main() -> int: +class Status(Enum): + USER = 1 + BOT = 2 + + +def prefer_left(lhs: str, rhs: Optional[str]) -> str: + if lhs != "": + return lhs + if rhs is not None: + return rhs + raise Exception("Set as an argument or in ~/.config/oregpt/config.yml") + + +# Add "type: ignore" to avoid this https://github.com/python/typeshed/issues/6156 +@click.command() # type: ignore +@click.option("--model_name", "-m", type=str, help="Model name in OpenAI (e.g, gpt-3.5-turbo, gpt-4)", default="") # type: ignore +@click.option("--assistant_role", "-a", type=str, help="Role setting for Assistant (AI)", default="") # type: ignore +def main(model_name: str, assistant_role: str) -> int: config = load_config() initialize_open_ai_key(config["openai"]) + model_name = prefer_left(model_name, config["openai"].get("model")) + assistant_role = prefer_left(assistant_role, config["character"]["assistant"].get("role")) std_in_out = StdInOut(config["character"], lambda: "To exit, type /q, /quit, /exit, or Ctrl + C") - bot = ChatBot(config["openai"]["model"], std_in_out) + bot = ChatBot(model_name, assistant_role, std_in_out) command_builder = CommandBuilder(config, bot) - try: - while True: + while True: + try: + status = Status.USER message = std_in_out.input().lower() if command := command_builder.build(message): command.execute() else: + status = Status.BOT if command_builder.looks_like_command(message): std_in_out.print_system("Invalid command. Valid commands are as the following:") command_builder.build("/help").execute() # type: ignore else: bot.respond(message) - except KeyboardInterrupt: - return 0 - except Exception as e: - raise Exception(f"Something happened: {str(e)}") from e + except KeyboardInterrupt: + if status == Status.BOT: + std_in_out.print_assistant("\n") + else: + return 0 + except Exception as e: + raise Exception(f"Something happened: {str(e)}") from e + return 1 return 0 diff --git a/oregpt/resources/config.yml b/oregpt/resources/config.yml index 8107474..d98cf0a 100644 --- a/oregpt/resources/config.yml +++ b/oregpt/resources/config.yml @@ -4,12 +4,13 @@ openai: # You can also specify OpenAI's API key here # api_key: character: - user: - name: Me - style: "#00BEFE" assistant: name: AI style: "#87CEEB" + role: "You are a chat bot" + user: + name: Me + style: "#00BEFE" system: name: System style: "#cc0000" diff --git a/poetry.lock b/poetry.lock index 2968db3..343ab82 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -333,7 +333,7 @@ files = [ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1121,4 +1121,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "be4b5246268bdfe46094f60814fc95355382a347678ace8e450fcf711b31611a" +content-hash = "45493e3694bfda018541e3a28f77173aef143a8961d180cce01a9e6a9f0d0490" diff --git a/pyproject.toml b/pyproject.toml index ab48350..10abc7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ python = "^3.10" openai = "^0.27.6" pyyaml = "^6.0" prompt-toolkit = "^3.0.38" +click = "^8.1.3" [tool.poetry.group.test.dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index 4903261..069782b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,8 +20,8 @@ def make_std_in_out(): return StdInOut({}, lambda: "Dummy") @staticmethod - def make_chat_bot(name: str): - return ChatBot(name, Helpers.make_std_in_out()) + def make_chat_bot(name: str, role: str): + return ChatBot(name, role, Helpers.make_std_in_out()) @pytest.fixture @@ -45,4 +45,4 @@ def _print_as_contextmanager(*args, **kwargs): monkeypatch.setattr(ChatCompletion, "create", _create) monkeypatch.setattr(StdInOut, "_print", _print) monkeypatch.setattr(StdInOut, "print_assistant_thinking", _print_as_contextmanager) - return helpers.make_chat_bot("Yes") + return helpers.make_chat_bot("THE AI", "You are a great chat bot") diff --git a/tests/test_chat_bot.py b/tests/test_chat_bot.py index a35d2f3..2e00121 100644 --- a/tests/test_chat_bot.py +++ b/tests/test_chat_bot.py @@ -2,8 +2,6 @@ import pytest -from oregpt.chat_bot import ChatBot - @pytest.fixture def tmp_file(tmpdir_factory): @@ -11,16 +9,16 @@ def tmp_file(tmpdir_factory): def test_initialized_property(helpers): - bot = helpers.make_chat_bot("THE AI") + bot = helpers.make_chat_bot("THE AI", "You are a bot") assert bot.model == "THE AI" - assert bot.log == ChatBot.SYSTEM_ROLE + assert bot.log == bot.assistant_role def test_respond_and_log(patched_bot): what_user_said = "Hello, world" - assert patched_bot.log == ChatBot.SYSTEM_ROLE + assert patched_bot.log == patched_bot.assistant_role assert pytest.DUMMY_CONTENT == patched_bot.respond(what_user_said) - assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + assert patched_bot.log == patched_bot.assistant_role + [ {"role": "user", "content": what_user_said}, {"role": "assistant", "content": pytest.DUMMY_CONTENT}, ] @@ -33,7 +31,7 @@ def test_save(tmp_file, patched_bot): with tmp_file.open("r") as file: assert patched_bot.log == json.load(file) - assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + assert patched_bot.log == patched_bot.assistant_role + [ {"role": "user", "content": what_user_said}, {"role": "assistant", "content": pytest.DUMMY_CONTENT}, ] @@ -44,7 +42,7 @@ def test_load(tmp_file, patched_bot, helpers): patched_bot.respond(what_user_said) patched_bot.save(str(tmp_file)) - bot = helpers.make_chat_bot("THE AI") + bot = helpers.make_chat_bot("THE AI", "You are a bot") bot.load(tmp_file) assert bot.log == patched_bot.log @@ -52,9 +50,9 @@ def test_load(tmp_file, patched_bot, helpers): def test_clear(patched_bot): what_user_said = "Hello, world" patched_bot.respond(what_user_said) - assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + assert patched_bot.log == patched_bot.assistant_role + [ {"role": "user", "content": what_user_said}, {"role": "assistant", "content": pytest.DUMMY_CONTENT}, ] patched_bot.clear() - assert patched_bot._log == ChatBot.SYSTEM_ROLE + assert patched_bot._log == patched_bot.assistant_role diff --git a/tests/test_command.py b/tests/test_command.py index 09ad7f4..27ffbcf 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -1,6 +1,5 @@ import pytest -from oregpt.chat_bot import ChatBot from oregpt.command import ( ClearCommand, CommandBuilder, @@ -13,14 +12,14 @@ def test_command_builder_build(helpers): - command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo")) + command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo", "You are a bot")) for command_type in [ExitCommand, ClearCommand, HistoryCommand, SaveCommand, LoadCommand, HelpCommand]: for representation in command_type.representations: assert isinstance(command_builder.build(f"/{representation}"), command_type) def test_command_builder_looks_like_command(helpers): - command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo")) + command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo", "You are a bot")) assert command_builder.looks_like_command("/hoge hoge") == True assert command_builder.looks_like_command("/hoge") == True assert command_builder.looks_like_command("/") == True @@ -39,4 +38,4 @@ def test_clear_command(patched_bot): cl = ClearCommand({}, patched_bot, []) patched_bot.respond("Hi, bot-san") cl.execute() - assert patched_bot.log == ChatBot.SYSTEM_ROLE + assert patched_bot.log == patched_bot.assistant_role