diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3441e63..00bfc95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,4 +24,7 @@ repos: rev: v1.2.0 hooks: - id: mypy - args: [--explicit-package-bases] + args: [ + --explicit-package-bases + ] + exclude: 'test_*' diff --git a/README.md b/README.md index 8aa33db..5add616 100644 --- a/README.md +++ b/README.md @@ -57,3 +57,15 @@ character: name: System style: "#cc0000" ``` + +## Supported commands on chat +| Command | Description | +| ---- | ---- | +| `/exit` | Exit from this chat tool | +| `/quit | Exit from this chat tool | +| `/q | Exit from this chat tool | +| `/clear | Clear chat history all | +| `/history | Show chat history in json format | +| `/save | Save chat hisotry in json format | +| `/load | Load chat hisotry from a json file | +| `/help | Show all commands which you can use in this chat tool | diff --git a/oregpt/chat_bot.py b/oregpt/chat_bot.py index 05c5826..cc05f01 100644 --- a/oregpt/chat_bot.py +++ b/oregpt/chat_bot.py @@ -1,6 +1,5 @@ import json -import pathlib -from datetime import datetime +from copy import deepcopy import openai @@ -8,6 +7,8 @@ class ChatBot: + SYSTEM_ROLE = [{"role": "system", "content": "You are a chat bot"}] + def __init__(self, model: str, std_in_out: StdInOut): self._std_in_out = std_in_out # Model list: gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301 @@ -16,17 +17,20 @@ def __init__(self, model: str, std_in_out: StdInOut): self._initialize_log() + def _initialize_log(self) -> None: + self._log: list[dict[str, str]] = deepcopy(ChatBot.SYSTEM_ROLE) + + @property + def model(self) -> str: + return self._model + @property def log(self) -> list[dict[str, str]]: return self._log - def _initialize_log(self) -> None: - # TODO - # Make system role - # https://community.openai.com/t/the-system-role-how-it-influences-the-chat-behavior/87353 - # https://learn.microsoft.com/ja-jp/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions#system-role - # self._log = [{"role": "system", "content": f"You are a chat bot."}] - self._log: list[dict[str, str]] = [] + @property + def std_in_out(self) -> StdInOut: + return self._std_in_out def respond(self, message: str) -> str: self._log.append({"role": "user", "content": message}) @@ -49,13 +53,14 @@ def respond(self, message: str) -> str: self._log.append({"role": "assistant", "content": content}) return content - def save(self, directory: str) -> str: - path = pathlib.Path(directory) - path.mkdir(parents=True, exist_ok=True) - file_name = str(path / datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S.json")) + def save(self, file_name: str) -> str: with open(file_name, "w", encoding="utf-8") as file: json.dump(self._log, file, indent=4, ensure_ascii=False) return file_name + def load(self, file_name: str) -> None: + with open(file_name, "r", encoding="utf-8") as file: + self._log = json.load(file) + def clear(self) -> None: self._initialize_log() diff --git a/oregpt/command.py b/oregpt/command.py index 42be332..ff569aa 100644 --- a/oregpt/command.py +++ b/oregpt/command.py @@ -1,42 +1,44 @@ +import os +import pathlib import sys from abc import ABC, abstractmethod +from datetime import datetime from typing import Any, Optional, Type from oregpt.chat_bot import ChatBot -from oregpt.stdinout import StdInOut class CommandBuilder: - classes: dict[str, Type["Command"]] = dict({}) + command_classes: dict[str, Type["Command"]] = dict({}) - def __init__(self, config: dict[str, Any], bot: ChatBot, std_in_out: StdInOut): + def __init__(self, config: dict[str, Any], bot: ChatBot): self._config = config self._bot = bot - self._std_in_out = std_in_out def build(self, message: str) -> Optional["Command"]: messages = message.split(" ") command = messages[0].strip() - args = messages[1:] if len(messages) >= 2 else [""] + args = messages[1:] if len(messages) >= 2 else [] + args = list(filter(None, args)) return ( - class_type(self._config, self._bot, self._std_in_out, args) - if (class_type := self.__class__.classes.get(command)) + class_type(self._config, self._bot, args) + if (class_type := CommandBuilder.command_classes.get(command)) else None ) -def register(cls: Type["Command"]) -> None: +def register(cls: Type["Command"]) -> Type["Command"]: for representation in cls.representations: - CommandBuilder.classes["/" + representation] = cls + CommandBuilder.command_classes["/" + representation] = cls + return cls class Command(ABC): representations: list[str] = [] - def __init__(self, config: dict[str, Any], bot: ChatBot, std_in_out: StdInOut, args: list[str]): + def __init__(self, config: dict[str, Any], bot: ChatBot, args: list[str]): self._config = config self._bot = bot - self._std_in_out = std_in_out self._args = args @abstractmethod @@ -62,7 +64,7 @@ class ClearCommand(Command): def execute(self) -> None: self._bot.clear() - self._std_in_out.print_system("Clear all conversation history") + self._bot.std_in_out.print_system("Clear all conversation history") @register @@ -72,7 +74,11 @@ class HistoryCommand(Command): representations: list[str] = ["history"] def execute(self) -> None: - self._std_in_out.print_system(str(self._bot.log)) + self._bot.std_in_out.print_system(str(self._bot.log)) + + +def _abspath(x: str) -> str: + return os.path.abspath(os.path.expanduser(x)) @register @@ -82,8 +88,32 @@ class SaveCommand(Command): representations: list[str] = ["save"] def execute(self) -> None: - file_name = self._bot.save(self._config["log"]) - self._std_in_out.print_system(f"Save all conversation history in {file_name}") + file_name = "" + if len(self._args) == 1: + file_name = _abspath(self._args[0].strip()) + directory = pathlib.Path(os.path.dirname(file_name)) + if file_name == "": + directory = pathlib.Path(_abspath(self._config["log"])) + file_name = str(directory / datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S.json")) + directory.mkdir(parents=True, exist_ok=True) + self._bot.save(file_name) + self._bot.std_in_out.print_system(f"Save all conversation history in {file_name}") + + +@register +class LoadCommand(Command): + """Load chat hisotry from a json file""" + + representations: list[str] = ["load"] + + def execute(self) -> None: + print(self._args) + if len(self._args) != 1: + self._bot.std_in_out.print_system("Loaded file was not specified as an argument") + return + file_name = _abspath(self._args[0]) + self._bot.load(file_name) + self._bot.std_in_out.print_system(f"Loaded chat history from {file_name}") @register @@ -93,5 +123,5 @@ class HelpCommand(Command): representations: list[str] = ["help"] def execute(self) -> None: - for k, v in CommandBuilder.classes.items(): - self._std_in_out.print_system(f"{k}: {v.__doc__}") + for k, v in CommandBuilder.command_classes.items(): + self._bot.std_in_out.print_system(f"{k}: {v.__doc__}") diff --git a/oregpt/main.py b/oregpt/main.py index 995bf89..497f553 100644 --- a/oregpt/main.py +++ b/oregpt/main.py @@ -39,7 +39,7 @@ def main() -> int: initialize_open_ai_key(config["openai"]) std_in_out = StdInOut(config["character"], lambda: "To exit, type q, quit, exit, or Ctrl + C") bot = ChatBot(config["openai"]["model"], std_in_out) - command_builder = CommandBuilder(config, bot, std_in_out) + command_builder = CommandBuilder(config, bot) try: while True: diff --git a/oregpt/stdinout.py b/oregpt/stdinout.py index e9cc7ef..dece07e 100644 --- a/oregpt/stdinout.py +++ b/oregpt/stdinout.py @@ -16,7 +16,11 @@ class Character: class StdInOut: def __init__(self, config: dict[str, Any], bottom_toolbar: Optional[AnyFormattedText]): - self._characters: dict[str, Character] = {} + self._characters: dict[str, Character] = { + "user": Character("Me", "#00BEFE"), + "assistant": Character("AI", "#87CEEB"), + "system": Character("System", "#cc0000"), + } for key, value in config.items(): self._characters[key] = Character(**value) self._bottom_toolbar = bottom_toolbar diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..72d4376 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,47 @@ +import contextlib + +import pytest +from openai import ChatCompletion + +from oregpt.chat_bot import ChatBot +from oregpt.stdinout import StdInOut + + +def pytest_configure(): + pytest.DUMMY_CONTENT = "Yep" + + +# Use this trick +# https://stackoverflow.com/a/42156088/3926333 +class Helpers: + @staticmethod + def make_std_in_out(): + return StdInOut({}, lambda: "Dummy") + + @staticmethod + def make_chat_bot(name: str): + return ChatBot(name, Helpers.make_std_in_out()) + + +@pytest.fixture +def helpers(): + return Helpers + + +@pytest.fixture +def patched_bot(monkeypatch, helpers): + def _create(*args, **kwargs): + return [{"choices": [{"delta": {"content": pytest.DUMMY_CONTENT}}]}] + + # Set monkey patch to avoid this error: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/406 + def _print(*args, **kwargs): + pass + + @contextlib.contextmanager + def _print_as_contextmanager(*args, **kwargs): + yield + + monkeypatch.setattr(ChatCompletion, "create", _create) + monkeypatch.setattr(StdInOut, "_print", _print) + monkeypatch.setattr(StdInOut, "print_assistant_thinking", _print_as_contextmanager) + return helpers.make_chat_bot("Yes") diff --git a/tests/test_chat_bot.py b/tests/test_chat_bot.py index 737762e..0b81818 100644 --- a/tests/test_chat_bot.py +++ b/tests/test_chat_bot.py @@ -1,14 +1,19 @@ import contextlib +import json +import pytest from openai import ChatCompletion from oregpt.chat_bot import ChatBot from oregpt.stdinout import StdInOut +DUMMY_CONTENT = "Yep" -def test_chat_bot_respond(monkeypatch): + +@pytest.fixture(scope="function") +def patched_bot(monkeypatch, helpers): def _create(*args, **kwargs): - return [{"choices": [{"delta": {"content": "Yep"}}]}] + return [{"choices": [{"delta": {"content": DUMMY_CONTENT}}]}] # Set monkey patch to avoid this error: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/406 def _print(*args, **kwargs): @@ -21,8 +26,59 @@ def _print_as_contextmanager(*args, **kwargs): monkeypatch.setattr(ChatCompletion, "create", _create) monkeypatch.setattr(StdInOut, "_print", _print) monkeypatch.setattr(StdInOut, "print_assistant_thinking", _print_as_contextmanager) + return helpers.make_chat_bot("Yes") + + +@pytest.fixture +def tmp_file(tmpdir_factory): + return tmpdir_factory.mktemp("data").join("test.json") + + +def test_initialized_property(helpers): + bot = helpers.make_chat_bot("THE AI") + assert bot.model == "THE AI" + assert bot.log == ChatBot.SYSTEM_ROLE + + +def test_respond_and_log(patched_bot): + what_user_said = "Hello, world" + assert patched_bot.log == ChatBot.SYSTEM_ROLE + assert DUMMY_CONTENT == patched_bot.respond(what_user_said) + assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + {"role": "user", "content": what_user_said}, + {"role": "assistant", "content": DUMMY_CONTENT}, + ] + + +def test_save(tmp_file, patched_bot): + what_user_said = "Hello, world???" + patched_bot.respond(what_user_said) + patched_bot.save(str(tmp_file)) + + with tmp_file.open("r") as file: + assert patched_bot.log == json.load(file) + assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + {"role": "user", "content": what_user_said}, + {"role": "assistant", "content": DUMMY_CONTENT}, + ] + + +def test_load(tmp_file, patched_bot, helpers): + what_user_said = "Hello, world" + patched_bot.respond(what_user_said) + patched_bot.save(str(tmp_file)) + + bot = helpers.make_chat_bot("THE AI") + bot.load(tmp_file) + assert bot.log == patched_bot.log - bot = ChatBot("ultra-ai", StdInOut({}, lambda: "Dummy")) - answer = bot.respond("Hello, world") - assert "Yep" == answer +def test_clear(patched_bot): + what_user_said = "Hello, world" + patched_bot.respond(what_user_said) + assert patched_bot.log == ChatBot.SYSTEM_ROLE + [ + {"role": "user", "content": what_user_said}, + {"role": "assistant", "content": DUMMY_CONTENT}, + ] + patched_bot.clear() + assert patched_bot._log == ChatBot.SYSTEM_ROLE diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 0000000..81d833a --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,32 @@ +import pytest + +from oregpt.chat_bot import ChatBot +from oregpt.command import ( + ClearCommand, + CommandBuilder, + ExitCommand, + HelpCommand, + HistoryCommand, + LoadCommand, + SaveCommand, +) + + +def test_command_builder(helpers): + command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo")) + for command_type in [ExitCommand, ClearCommand, HistoryCommand, SaveCommand, LoadCommand, HelpCommand]: + for representation in command_type.representations: + assert isinstance(command_builder.build(f"/{representation}"), command_type) + + +def test_exit_command(patched_bot): + command = ExitCommand({}, patched_bot, []) + with pytest.raises(SystemExit): + command.execute() + + +def test_clear_command(patched_bot): + cl = ClearCommand({}, patched_bot, []) + patched_bot.respond("Hi, bot-san") + cl.execute() + patched_bot.log == ChatBot.SYSTEM_ROLE