From cb8e12d1272f4154d79de4211762c10a79c69266 Mon Sep 17 00:00:00 2001
From: victor <52110451+cs50victor@users.noreply.github.com>
Date: Tue, 26 Mar 2024 22:28:41 -0400
Subject: [PATCH] refactor: simplify code (part I) (#3)

* refactor: default to ollama llm locally

* refactor: simplify code further

* chore: tauri build ci

* fix: ci

* chore: run ruff linter
---
 .github/workflows/ci-rs.yml                    |   7 +-
 .gitignore                                     |   3 +
 .../clients/base_device.py => mac_device.py}   |  26 ++--
 core/source/clients/__init__.py                |   0
 core/source/clients/mac/__init__.py            |   0
 core/source/clients/mac/device.py              |  10 --
 core/source/server/conftest.py                 |   3 -
 core/source/server/i.py                        |   2 -
 core/source/server/server.py                   |   6 +-
 .../server/services/stt/local-whisper/stt.py   |   2 -
 core/source/server/services/tts/openai/tts.py  |   2 -
 core/source/server/skills/schedule.py          |   1 -
 core/source/server/tests/__init__.py           |   0
 core/source/server/tests/test_run.py           |  41 ------
 core/source/server/tunnel.py                   |   4 +-
 core/source/server/utils/kernel.py             |   2 -
 core/source/server/utils/local_mode.py         | 123 ++++++------------
 core/source/server/utils/process_utils.py      |   2 +-
 core/start.py                                  |   5 +-
 19 files changed, 67 insertions(+), 172 deletions(-)
 rename core/{source/clients/base_device.py => mac_device.py} (96%)
 delete mode 100644 core/source/clients/__init__.py
 delete mode 100644 core/source/clients/mac/__init__.py
 delete mode 100644 core/source/clients/mac/device.py
 delete mode 100644 core/source/server/tests/__init__.py
 delete mode 100644 core/source/server/tests/test_run.py

diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml
index cdc298e..f61e642 100644
--- a/.github/workflows/ci-rs.yml
+++ b/.github/workflows/ci-rs.yml
@@ -5,8 +5,8 @@ on:
     paths:
       - "**/.github/workflows/ci-rs.yml"
       - "**/Cargo.lock"
-      - "**/src/**"
       - "**/Cargo.toml"
+      - "**/src-tauri/**"
       - "**/rust-toolchain"
       - "**/.taplo.toml"
   workflow_dispatch:
@@ -42,8 +42,11 @@ jobs:
           components: rustfmt, clippy
           enable-sccache: "true"
 
+      - name: Install Tauri CLI
+        run: cargo install tauri-cli
+
       - name: Build
-        run: cargo build --release
+        run: cargo tauri build
 
       - name: Test
         run: cargo test --release
diff --git a/.gitignore b/.gitignore
index 6aac21d..55fd886 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,3 +52,6 @@ dist-ssr
 *.njsproj
 *.sln
 *.sw?
+
+# python
+.cenv
diff --git a/core/source/clients/base_device.py b/core/mac_device.py
index 6b71b7c..68789c5 100644
--- a/core/source/clients/base_device.py
+++ b/core/mac_device.py
@@ -4,10 +4,7 @@
 import os
 import asyncio
 import threading
-import os
 import pyaudio
-from starlette.websockets import WebSocket
-from queue import Queue
 from pynput import keyboard
 import json
 import traceback
@@ -23,17 +20,17 @@ import base64
 from interpreter import interpreter # Just for code execution. Maybe we should let people do from interpreter.computer import run?
 # In the future, I guess kernel watching code should be elsewhere? Somewhere server / client agnostic?
-from ..server.utils.kernel import put_kernel_messages_into_queue
-from ..server.utils.process_utils import kill_process_tree
+from source.server.utils.kernel import put_kernel_messages_into_queue
+from source.server.utils.process_utils import kill_process_tree
 
-from ..server.utils.logs import setup_logging
-from ..server.utils.logs import logger
+from source.server.utils.logs import setup_logging
+from source.server.utils.logs import logger
 setup_logging()
 
 os.environ["STT_RUNNER"] = "server"
 os.environ["TTS_RUNNER"] = "server"
 
-from ..utils.accumulator import Accumulator
+from source.utils.accumulator import Accumulator
 
 accumulator = Accumulator()
 
@@ -324,4 +321,15 @@ async def start_async(self):
     def start(self):
         if os.getenv('TEACH_MODE') != "True":
             asyncio.run(self.start_async())
-            p.terminate()
\ No newline at end of file
+            p.terminate()
+
+device = Device()
+
+def run_device(server_url=None):
+    # Keep the Device's default server_url when none is provided
+    if server_url:
+        device.server_url = server_url
+    device.start()
+
+if __name__ == "__main__":
+    run_device()
\ No newline at end of file
diff --git a/core/source/clients/__init__.py b/core/source/clients/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/core/source/clients/mac/__init__.py b/core/source/clients/mac/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/core/source/clients/mac/device.py b/core/source/clients/mac/device.py
deleted file mode 100644
index a9a79c0..0000000
--- a/core/source/clients/mac/device.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from ..base_device import Device
-
-device = Device()
-
-def main(server_url):
-    device.server_url = server_url
-    device.start()
-
-if __name__ == "__main__":
-    main()
diff --git a/core/source/server/conftest.py b/core/source/server/conftest.py
index 4684194..6ca7ddd 100644
--- a/core/source/server/conftest.py
+++ b/core/source/server/conftest.py
@@ -1,8 +1,5 @@
-import os
-import sys
 import pytest
 from source.server.i import configure_interpreter
-from unittest.mock import Mock
 from interpreter import OpenInterpreter
 from fastapi.testclient import TestClient
 from .server import app
diff --git a/core/source/server/i.py b/core/source/server/i.py
index 89deb5e..170d850 100644
--- a/core/source/server/i.py
+++ b/core/source/server/i.py
@@ -1,11 +1,9 @@
 from dotenv import load_dotenv
 load_dotenv() # take environment variables from .env.
-import os import glob import time import json -from pathlib import Path from interpreter import OpenInterpreter import shutil diff --git a/core/source/server/server.py b/core/source/server/server.py index f468634..f575a3c 100644 --- a/core/source/server/server.py +++ b/core/source/server/server.py @@ -3,7 +3,6 @@ import traceback from platformdirs import user_data_dir -import ast import json import queue import os @@ -13,9 +12,7 @@ from fastapi import FastAPI, Request from fastapi.responses import PlainTextResponse from starlette.websockets import WebSocket, WebSocketDisconnect -from pathlib import Path import asyncio -import urllib.parse from .utils.kernel import put_kernel_messages_into_queue from .i import configure_interpreter from interpreter import interpreter @@ -352,7 +349,6 @@ def stream_tts(sentence): from uvicorn import Config, Server import os -import platform from importlib import import_module # these will be overwritten @@ -363,7 +359,7 @@ def stream_tts(sentence): async def startup_event(): server_url = f"{HOST}:{PORT}" print("") - print_markdown(f"\n*Ready.*\n") + print_markdown("\n*Ready.*\n") print("") @app.on_event("shutdown") diff --git a/core/source/server/services/stt/local-whisper/stt.py b/core/source/server/services/stt/local-whisper/stt.py index b318e8e..79d9bd9 100644 --- a/core/source/server/services/stt/local-whisper/stt.py +++ b/core/source/server/services/stt/local-whisper/stt.py @@ -10,8 +10,6 @@ import ffmpeg import subprocess -import os -import subprocess class Stt: diff --git a/core/source/server/services/tts/openai/tts.py b/core/source/server/services/tts/openai/tts.py index a3759bb..8fe9d1d 100644 --- a/core/source/server/services/tts/openai/tts.py +++ b/core/source/server/services/tts/openai/tts.py @@ -2,8 +2,6 @@ import tempfile from openai import OpenAI import os -import subprocess -import tempfile client = OpenAI() diff --git a/core/source/server/skills/schedule.py b/core/source/server/skills/schedule.py index f351c59..e3ae1c4 100644 --- a/core/source/server/skills/schedule.py +++ b/core/source/server/skills/schedule.py @@ -3,7 +3,6 @@ from pytimeparse import parse from crontab import CronTab from uuid import uuid4 -from datetime import datetime from platformdirs import user_data_dir def schedule(message="", start=None, interval=None) -> None: diff --git a/core/source/server/tests/__init__.py b/core/source/server/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/core/source/server/tests/test_run.py b/core/source/server/tests/test_run.py deleted file mode 100644 index ce04932..0000000 --- a/core/source/server/tests/test_run.py +++ /dev/null @@ -1,41 +0,0 @@ -# test_main.py -import subprocess -import uuid -import pytest -from source.server.i import configure_interpreter -from unittest.mock import Mock -from fastapi.testclient import TestClient - - - -@pytest.mark.asyncio -def test_ping(client): - response = client.get("/ping") - assert response.status_code == 200 - assert response.text == "pong" - - -# def test_interpreter_chat(mock_interpreter): -# # Set up a sample conversation -# messages = [ -# {"role": "user", "type": "message", "content": "Hello."}, -# {"role": "assistant", "type": "message", "content": "Hi there!"}, -# # Add more messages as needed -# ] - -# # Configure the mock interpreter with the sample conversation -# mock_interpreter.messages = messages - -# # Simulate additional user input -# user_input = {"role": "user", "type": "message", "content": "How are you?"} -# mock_interpreter.chat([user_input]) - 
-# # Ensure the interpreter processed the user input -# assert len(mock_interpreter.messages) == len(messages) -# assert mock_interpreter.messages[-1]["role"] == "assistant" -# assert "don't have feelings" in mock_interpreter.messages[-1]["content"] - -# def test_interpreter_configuration(mock_interpreter): -# # Test interpreter configuration -# interpreter = configure_interpreter(mock_interpreter) -# assert interpreter is not None \ No newline at end of file diff --git a/core/source/server/tunnel.py b/core/source/server/tunnel.py index 5d55a97..8d74f0d 100644 --- a/core/source/server/tunnel.py +++ b/core/source/server/tunnel.py @@ -1,12 +1,10 @@ -import os import subprocess import re -import shutil import time from ..utils.print_markdown import print_markdown def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10001): - print_markdown(f"Exposing server to the internet...") + print_markdown("Exposing server to the internet...") if tunnel_method == "bore": try: diff --git a/core/source/server/utils/kernel.py b/core/source/server/utils/kernel.py index 91433b1..88463af 100644 --- a/core/source/server/utils/kernel.py +++ b/core/source/server/utils/kernel.py @@ -3,10 +3,8 @@ import asyncio import subprocess -import platform from .logs import setup_logging -from .logs import logger setup_logging() def get_kernel_messages(): diff --git a/core/source/server/utils/local_mode.py b/core/source/server/utils/local_mode.py index 6d7113c..3f3aaf4 100644 --- a/core/source/server/utils/local_mode.py +++ b/core/source/server/utils/local_mode.py @@ -1,6 +1,4 @@ import sys -import os -import platform import subprocess import time import inquirer @@ -8,95 +6,48 @@ def select_local_model(): - - # START OF LOCAL MODEL PROVIDER LOGIC - interpreter.display_message("> 01 is compatible with several local model providers.\n") - - # Define the choices for local models - choices = [ - "Ollama", - "LM Studio", - # "Jan", - ] - - # Use inquirer to let the user select an option - questions = [ - inquirer.List( - "model", - message="Which one would you like to use?", - choices=choices, - ), - ] - answers = inquirer.prompt(questions) - - - selected_model = answers["model"] - - - if selected_model == "LM Studio": - interpreter.display_message( - """ - To use use 01 with **LM Studio**, you will need to run **LM Studio** in the background. - - 1. Download **LM Studio** from [https://lmstudio.ai/](https://lmstudio.ai/), then start it. - 2. Select a language model then click **Download**. - 3. Click the **<->** button on the left (below the chat button). - 4. Select your model at the top, then click **Start Server**. - - - Once the server is running, you can begin your conversation below. - - """ - ) - time.sleep(1) + selected_model = "Ollama" + try: - interpreter.llm.api_base = "http://localhost:1234/v1" - interpreter.llm.max_tokens = 1000 - interpreter.llm.context_window = 8000 - interpreter.llm.api_key = "x" - - elif selected_model == "Ollama": - try: - - # List out all downloaded ollama models. Will fail if ollama isn't installed - result = subprocess.run(["ollama", "list"], capture_output=True, text=True, check=True) - lines = result.stdout.split('\n') - names = [line.split()[0].replace(":latest", "") for line in lines[1:] if line.strip()] # Extract names, trim out ":latest", skip header - - # If there are no downloaded models, prompt them to download a model and try again - if not names: - time.sleep(1) - - interpreter.display_message(f"\nYou don't have any Ollama models downloaded. 
To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
-
-                print("Please download a model then try again\n")
-                time.sleep(2)
-                sys.exit(1)
+        # List out all downloaded ollama models. Will fail if ollama isn't installed
+        result = subprocess.run(["ollama", "list"], capture_output=True, text=True, check=True)
+        lines = result.stdout.split('\n')
+        names = [line.split()[0].replace(":latest", "") for line in lines[1:] if line.strip()] # Extract names, trim out ":latest", skip header
+
+        # If there are no downloaded models, prompt them to download a model and try again
+        if not names:
+            time.sleep(1)
 
-            # If there are models, prompt them to select one
-            else:
-                time.sleep(1)
-                interpreter.display_message(f"**{len(names)} Ollama model{'s' if len(names) != 1 else ''} found.** To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
-
-                # Create a new inquirer selection from the names
-                name_question = [
-                    inquirer.List('name', message="Select a downloaded Ollama model", choices=names),
-                ]
-                name_answer = inquirer.prompt(name_question)
-                selected_name = name_answer['name'] if name_answer else None
-
-                # Set the model to the selected model
-                interpreter.llm.model = f"ollama/{selected_name}"
-                interpreter.display_message(f"\nUsing Ollama model: `{selected_name}` \n")
-                time.sleep(1)
+            interpreter.display_message("\nYou don't have any Ollama models downloaded. To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
 
-        # If Ollama is not installed or not recognized as a command, prompt the user to download Ollama and try again
-        except (subprocess.CalledProcessError, FileNotFoundError) as e:
-            print("Ollama is not installed or not recognized as a command.")
-            time.sleep(1)
-            interpreter.display_message(f"\nPlease visit [https://ollama.com/](https://ollama.com/) to download Ollama and try again\n")
+            print("Please download a model then try again\n")
             time.sleep(2)
             sys.exit(1)
+
+        # If there are models, prompt them to select one
+        else:
+            time.sleep(1)
+            interpreter.display_message(f"**{len(names)} Ollama model{'s' if len(names) != 1 else ''} found.** To download a new model, run `ollama run <model-name>`, then start a new 01 session. 
\n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
+
+            # Create a new inquirer selection from the names
+            name_question = [
+                inquirer.List('name', message="Select a downloaded Ollama model", choices=names),
+            ]
+            name_answer = inquirer.prompt(name_question)
+            selected_name = name_answer['name'] if name_answer else None
+
+            # Set the model to the selected model
+            interpreter.llm.model = f"ollama/{selected_name}"
+            interpreter.display_message(f"\nUsing Ollama model: `{selected_name}` \n")
+            time.sleep(1)
+
+    # If Ollama is not installed or not recognized as a command, prompt the user to download Ollama and try again
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        print("Ollama is not installed or not recognized as a command.")
+        time.sleep(1)
+        interpreter.display_message("\nPlease visit [https://ollama.com/](https://ollama.com/) to download Ollama and try again\n")
+        time.sleep(2)
+        sys.exit(1)
 
     # elif selected_model == "Jan":
     #     interpreter.display_message(
diff --git a/core/source/server/utils/process_utils.py b/core/source/server/utils/process_utils.py
index adcf028..ebcd08c 100644
--- a/core/source/server/utils/process_utils.py
+++ b/core/source/server/utils/process_utils.py
@@ -25,4 +25,4 @@ def kill_process_tree():
         except psutil.NoSuchProcess:
             print(f"Process {pid} does not exist or is already terminated")
         except psutil.AccessDenied:
-            print(f"Permission denied to terminate some processes")
+            print("Permission denied to terminate some processes")
diff --git a/core/start.py b/core/start.py
index 0c4e198..1af4615 100644
--- a/core/start.py
+++ b/core/start.py
@@ -2,10 +2,10 @@
 import asyncio
 import threading
 import os
-import importlib
 from source.server.tunnel import create_tunnel
 from source.server.server import main
 from source.server.utils.local_mode import select_local_model
+from mac_device import run_device
 import signal
 
 app = typer.Typer()
@@ -117,8 +117,7 @@ def handle_exit(signum, frame):
     tunnel_thread.start()
 
     if client:
-        module = importlib.import_module(f".clients.mac.device", package='source')
-        client_thread = threading.Thread(target=module.main, args=[server_url])
+        client_thread = threading.Thread(target=run_device, args=[server_url])
         client_thread.start()
 
     try:
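
For reviewers who want to sanity-check the new client wiring without checking out the branch: the pattern start.py now relies on is a module-level device object, a run_device(server_url) entry point that blocks in the device loop, and a client thread started from the launcher. Below is a minimal, self-contained sketch of that pattern. The Device class here is a stand-in stub (the real one in core/mac_device.py drives audio and a websocket), and the localhost URL is illustrative only.

import signal
import sys
import threading
import time


class Device:
    """Stand-in for mac_device.Device: start() blocks until stopped."""

    def __init__(self):
        self.server_url = "0.0.0.0:10001"  # overwritten via run_device, as in the patch
        self._stop = threading.Event()

    def start(self):
        # The real device loops on audio/websocket I/O; here we just tick.
        while not self._stop.is_set():
            print(f"device loop: talking to {self.server_url}")
            time.sleep(1)

    def stop(self):
        self._stop.set()


device = Device()


def run_device(server_url=None):
    # Same shape as the core/mac_device.py entry point: configure the
    # module-level device, then block in its loop.
    if server_url:
        device.server_url = server_url
    device.start()


if __name__ == "__main__":
    # start.py equivalent: run the client on a daemon thread so the main
    # thread stays free for signal handling.
    client_thread = threading.Thread(target=run_device, args=["localhost:10001"], daemon=True)
    client_thread.start()

    def handle_exit(signum, frame):
        device.stop()
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_exit)
    signal.signal(signal.SIGTERM, handle_exit)
    while True:
        time.sleep(1)

Running this prints the device loop until Ctrl-C, which exercises the same thread-plus-signal shutdown path the patched start.py uses.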