Skip to content

Commit

Permalink
Fixes mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
luiscape committed Aug 28, 2023
1 parent 7c55f11 commit 7c30198
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions 06_gpu_and_ml/mlc/mlc_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# # Llama 2 inference with MLC
#
# Machine Learning Compilation (MLC) is high-performance tool for serving
# [Machine Learning Compilation (MLC)](https://mlc.ai/mlc-llm/) is high-performance tool for serving
# LLMs including Llama 2. We will use the `mlc_chat` and the pre-compiled
# Llama 2 binaries to run inference using a Modal GPU.
#
Expand All @@ -10,6 +10,8 @@
import queue
import threading

from typing import Generator, List, Dict

# Determine which [GPU](https://modal.com/docs/guide/gpu#gpu-acceleration) you want to use.
GPU: str = "a10g"

Expand Down Expand Up @@ -70,7 +72,7 @@
# memory and run inference on an input prompt. This is a generator, streaming
# tokens back to the client as they are generated.
@stub.function(gpu=GPU)
def generate(prompt: str) -> dict[str, str]:
def generate(prompt: str) -> Generator[Dict[str, str], None, None]:
from mlc_chat import ChatModule
from mlc_chat.callback import DeltaCallback

Expand All @@ -82,9 +84,9 @@ def generate(prompt: str) -> dict[str, str]:
class QueueCallback(DeltaCallback):
"""Stream the output of the chat module to client."""

def __init__(self, callback_interval: int = 2):
def __init__(self, callback_interval: float):
super().__init__()
self.queue = queue.Queue()
self.queue:queue.Queue = queue.Queue()
self.stopped = False
self.callback_interval = callback_interval

Expand All @@ -99,7 +101,7 @@ def stopped_callback(self):
model=f"/dist/prebuilt/mlc-chat-Llama-2-{LLAMA_MODEL_SIZE}-chat-hf-q4f16_1",
lib_path=f"/dist/prebuilt/lib/Llama-2-{LLAMA_MODEL_SIZE}-chat-hf-q4f16_1-cuda.so",
)
queue_callback = QueueCallback(callback_interval=1)
queue_callback = QueueCallback(callback_interval=0.1)

# Generate tokens in a background thread so we can yield tokens
# to caller as a generator.
Expand Down Expand Up @@ -129,7 +131,7 @@ def main(prompt: str):
import curses

def _generate(stdscr):
buffer = []
buffer:List[str] = []

def _buffered_message():
return "".join(buffer) + ("\n" * 4)
Expand Down

0 comments on commit 7c30198

Please sign in to comment.