Skip to content

Commit

Permalink
Load command and some refactoring (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
shinichi-takayanagi authored May 15, 2023
1 parent ec926b7 commit 1ac6fb0
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 38 deletions.
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ repos:
rev: v1.2.0
hooks:
- id: mypy
args: [--explicit-package-bases]
args: [
--explicit-package-bases
]
exclude: 'test_*'
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ character:
name: System
style: "#cc0000"
```
## Supported commands on chat
| Command | Description |
| ---- | ---- |
| `/exit` | Exit from this chat tool |
| `/quit | Exit from this chat tool |
| `/q | Exit from this chat tool |
| `/clear | Clear chat history all |
| `/history | Show chat history in json format |
| `/save | Save chat hisotry in json format |
| `/load | Load chat hisotry from a json file |
| `/help | Show all commands which you can use in this chat tool |
31 changes: 18 additions & 13 deletions oregpt/chat_bot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import pathlib
from datetime import datetime
from copy import deepcopy

import openai

from oregpt.stdinout import StdInOut


class ChatBot:
SYSTEM_ROLE = [{"role": "system", "content": "You are a chat bot"}]

def __init__(self, model: str, std_in_out: StdInOut):
self._std_in_out = std_in_out
# Model list: gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
Expand All @@ -16,17 +17,20 @@ def __init__(self, model: str, std_in_out: StdInOut):

self._initialize_log()

def _initialize_log(self) -> None:
self._log: list[dict[str, str]] = deepcopy(ChatBot.SYSTEM_ROLE)

@property
def model(self) -> str:
return self._model

@property
def log(self) -> list[dict[str, str]]:
return self._log

def _initialize_log(self) -> None:
# TODO
# Make system role
# https://community.openai.com/t/the-system-role-how-it-influences-the-chat-behavior/87353
# https://learn.microsoft.com/ja-jp/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions#system-role
# self._log = [{"role": "system", "content": f"You are a chat bot."}]
self._log: list[dict[str, str]] = []
@property
def std_in_out(self) -> StdInOut:
return self._std_in_out

def respond(self, message: str) -> str:
self._log.append({"role": "user", "content": message})
Expand All @@ -49,13 +53,14 @@ def respond(self, message: str) -> str:
self._log.append({"role": "assistant", "content": content})
return content

def save(self, directory: str) -> str:
path = pathlib.Path(directory)
path.mkdir(parents=True, exist_ok=True)
file_name = str(path / datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S.json"))
def save(self, file_name: str) -> str:
with open(file_name, "w", encoding="utf-8") as file:
json.dump(self._log, file, indent=4, ensure_ascii=False)
return file_name

def load(self, file_name: str) -> None:
with open(file_name, "r", encoding="utf-8") as file:
self._log = json.load(file)

def clear(self) -> None:
self._initialize_log()
64 changes: 47 additions & 17 deletions oregpt/command.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,44 @@
import os
import pathlib
import sys
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Optional, Type

from oregpt.chat_bot import ChatBot
from oregpt.stdinout import StdInOut


class CommandBuilder:
classes: dict[str, Type["Command"]] = dict({})
command_classes: dict[str, Type["Command"]] = dict({})

def __init__(self, config: dict[str, Any], bot: ChatBot, std_in_out: StdInOut):
def __init__(self, config: dict[str, Any], bot: ChatBot):
self._config = config
self._bot = bot
self._std_in_out = std_in_out

def build(self, message: str) -> Optional["Command"]:
messages = message.split(" ")
command = messages[0].strip()
args = messages[1:] if len(messages) >= 2 else [""]
args = messages[1:] if len(messages) >= 2 else []
args = list(filter(None, args))
return (
class_type(self._config, self._bot, self._std_in_out, args)
if (class_type := self.__class__.classes.get(command))
class_type(self._config, self._bot, args)
if (class_type := CommandBuilder.command_classes.get(command))
else None
)


def register(cls: Type["Command"]) -> None:
def register(cls: Type["Command"]) -> Type["Command"]:
for representation in cls.representations:
CommandBuilder.classes["/" + representation] = cls
CommandBuilder.command_classes["/" + representation] = cls
return cls


class Command(ABC):
representations: list[str] = []

def __init__(self, config: dict[str, Any], bot: ChatBot, std_in_out: StdInOut, args: list[str]):
def __init__(self, config: dict[str, Any], bot: ChatBot, args: list[str]):
self._config = config
self._bot = bot
self._std_in_out = std_in_out
self._args = args

@abstractmethod
Expand All @@ -62,7 +64,7 @@ class ClearCommand(Command):

def execute(self) -> None:
self._bot.clear()
self._std_in_out.print_system("Clear all conversation history")
self._bot.std_in_out.print_system("Clear all conversation history")


@register
Expand All @@ -72,7 +74,11 @@ class HistoryCommand(Command):
representations: list[str] = ["history"]

def execute(self) -> None:
self._std_in_out.print_system(str(self._bot.log))
self._bot.std_in_out.print_system(str(self._bot.log))


def _abspath(x: str) -> str:
return os.path.abspath(os.path.expanduser(x))


@register
Expand All @@ -82,8 +88,32 @@ class SaveCommand(Command):
representations: list[str] = ["save"]

def execute(self) -> None:
file_name = self._bot.save(self._config["log"])
self._std_in_out.print_system(f"Save all conversation history in {file_name}")
file_name = ""
if len(self._args) == 1:
file_name = _abspath(self._args[0].strip())
directory = pathlib.Path(os.path.dirname(file_name))
if file_name == "":
directory = pathlib.Path(_abspath(self._config["log"]))
file_name = str(directory / datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S.json"))
directory.mkdir(parents=True, exist_ok=True)
self._bot.save(file_name)
self._bot.std_in_out.print_system(f"Save all conversation history in {file_name}")


@register
class LoadCommand(Command):
"""Load chat hisotry from a json file"""

representations: list[str] = ["load"]

def execute(self) -> None:
print(self._args)
if len(self._args) != 1:
self._bot.std_in_out.print_system("Loaded file was not specified as an argument")
return
file_name = _abspath(self._args[0])
self._bot.load(file_name)
self._bot.std_in_out.print_system(f"Loaded chat history from {file_name}")


@register
Expand All @@ -93,5 +123,5 @@ class HelpCommand(Command):
representations: list[str] = ["help"]

def execute(self) -> None:
for k, v in CommandBuilder.classes.items():
self._std_in_out.print_system(f"{k}: {v.__doc__}")
for k, v in CommandBuilder.command_classes.items():
self._bot.std_in_out.print_system(f"{k}: {v.__doc__}")
2 changes: 1 addition & 1 deletion oregpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main() -> int:
initialize_open_ai_key(config["openai"])
std_in_out = StdInOut(config["character"], lambda: "To exit, type q, quit, exit, or Ctrl + C")
bot = ChatBot(config["openai"]["model"], std_in_out)
command_builder = CommandBuilder(config, bot, std_in_out)
command_builder = CommandBuilder(config, bot)

try:
while True:
Expand Down
6 changes: 5 additions & 1 deletion oregpt/stdinout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ class Character:

class StdInOut:
def __init__(self, config: dict[str, Any], bottom_toolbar: Optional[AnyFormattedText]):
self._characters: dict[str, Character] = {}
self._characters: dict[str, Character] = {
"user": Character("Me", "#00BEFE"),
"assistant": Character("AI", "#87CEEB"),
"system": Character("System", "#cc0000"),
}
for key, value in config.items():
self._characters[key] = Character(**value)
self._bottom_toolbar = bottom_toolbar
Expand Down
47 changes: 47 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import contextlib

import pytest
from openai import ChatCompletion

from oregpt.chat_bot import ChatBot
from oregpt.stdinout import StdInOut


def pytest_configure():
pytest.DUMMY_CONTENT = "Yep"


# Use this trick
# https://stackoverflow.com/a/42156088/3926333
class Helpers:
@staticmethod
def make_std_in_out():
return StdInOut({}, lambda: "Dummy")

@staticmethod
def make_chat_bot(name: str):
return ChatBot(name, Helpers.make_std_in_out())


@pytest.fixture
def helpers():
return Helpers


@pytest.fixture
def patched_bot(monkeypatch, helpers):
def _create(*args, **kwargs):
return [{"choices": [{"delta": {"content": pytest.DUMMY_CONTENT}}]}]

# Set monkey patch to avoid this error: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/406
def _print(*args, **kwargs):
pass

@contextlib.contextmanager
def _print_as_contextmanager(*args, **kwargs):
yield

monkeypatch.setattr(ChatCompletion, "create", _create)
monkeypatch.setattr(StdInOut, "_print", _print)
monkeypatch.setattr(StdInOut, "print_assistant_thinking", _print_as_contextmanager)
return helpers.make_chat_bot("Yes")
66 changes: 61 additions & 5 deletions tests/test_chat_bot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import contextlib
import json

import pytest
from openai import ChatCompletion

from oregpt.chat_bot import ChatBot
from oregpt.stdinout import StdInOut

DUMMY_CONTENT = "Yep"

def test_chat_bot_respond(monkeypatch):

@pytest.fixture(scope="function")
def patched_bot(monkeypatch, helpers):
def _create(*args, **kwargs):
return [{"choices": [{"delta": {"content": "Yep"}}]}]
return [{"choices": [{"delta": {"content": DUMMY_CONTENT}}]}]

# Set monkey patch to avoid this error: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/406
def _print(*args, **kwargs):
Expand All @@ -21,8 +26,59 @@ def _print_as_contextmanager(*args, **kwargs):
monkeypatch.setattr(ChatCompletion, "create", _create)
monkeypatch.setattr(StdInOut, "_print", _print)
monkeypatch.setattr(StdInOut, "print_assistant_thinking", _print_as_contextmanager)
return helpers.make_chat_bot("Yes")


@pytest.fixture
def tmp_file(tmpdir_factory):
return tmpdir_factory.mktemp("data").join("test.json")


def test_initialized_property(helpers):
bot = helpers.make_chat_bot("THE AI")
assert bot.model == "THE AI"
assert bot.log == ChatBot.SYSTEM_ROLE


def test_respond_and_log(patched_bot):
what_user_said = "Hello, world"
assert patched_bot.log == ChatBot.SYSTEM_ROLE
assert DUMMY_CONTENT == patched_bot.respond(what_user_said)
assert patched_bot.log == ChatBot.SYSTEM_ROLE + [
{"role": "user", "content": what_user_said},
{"role": "assistant", "content": DUMMY_CONTENT},
]


def test_save(tmp_file, patched_bot):
what_user_said = "Hello, world???"
patched_bot.respond(what_user_said)
patched_bot.save(str(tmp_file))

with tmp_file.open("r") as file:
assert patched_bot.log == json.load(file)
assert patched_bot.log == ChatBot.SYSTEM_ROLE + [
{"role": "user", "content": what_user_said},
{"role": "assistant", "content": DUMMY_CONTENT},
]


def test_load(tmp_file, patched_bot, helpers):
what_user_said = "Hello, world"
patched_bot.respond(what_user_said)
patched_bot.save(str(tmp_file))

bot = helpers.make_chat_bot("THE AI")
bot.load(tmp_file)
assert bot.log == patched_bot.log

bot = ChatBot("ultra-ai", StdInOut({}, lambda: "Dummy"))

answer = bot.respond("Hello, world")
assert "Yep" == answer
def test_clear(patched_bot):
what_user_said = "Hello, world"
patched_bot.respond(what_user_said)
assert patched_bot.log == ChatBot.SYSTEM_ROLE + [
{"role": "user", "content": what_user_said},
{"role": "assistant", "content": DUMMY_CONTENT},
]
patched_bot.clear()
assert patched_bot._log == ChatBot.SYSTEM_ROLE
32 changes: 32 additions & 0 deletions tests/test_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from oregpt.chat_bot import ChatBot
from oregpt.command import (
ClearCommand,
CommandBuilder,
ExitCommand,
HelpCommand,
HistoryCommand,
LoadCommand,
SaveCommand,
)


def test_command_builder(helpers):
command_builder = CommandBuilder({}, helpers.make_chat_bot("Yahoo"))
for command_type in [ExitCommand, ClearCommand, HistoryCommand, SaveCommand, LoadCommand, HelpCommand]:
for representation in command_type.representations:
assert isinstance(command_builder.build(f"/{representation}"), command_type)


def test_exit_command(patched_bot):
command = ExitCommand({}, patched_bot, [])
with pytest.raises(SystemExit):
command.execute()


def test_clear_command(patched_bot):
cl = ClearCommand({}, patched_bot, [])
patched_bot.respond("Hi, bot-san")
cl.execute()
patched_bot.log == ChatBot.SYSTEM_ROLE

0 comments on commit 1ac6fb0

Please sign in to comment.