Add web endpoint to the inference server
erikbern committed Mar 12, 2024
1 parent bcc4e63 commit 3aa7a9e
Showing 2 changed files with 34 additions and 13 deletions.
45 changes: 33 additions & 12 deletions src/inference.py
@@ -1,7 +1,9 @@
import time
from pathlib import Path
import yaml

import modal
from fastapi.responses import StreamingResponse

from .common import stub, vllm_image, VOLUME_CONFIG

@@ -16,14 +18,25 @@
    container_idle_timeout=120,
)
class Inference:
    def __init__(self, run_folder: str) -> None:
        import yaml

        with open(f"{run_folder}/config.yml") as f:
            config = yaml.safe_load(f.read())
        model_path = (Path(run_folder) / config["output_dir"] / "merged").resolve()

    def __init__(self, run_name: str = "", model_dir: str = "/runs") -> None:
        self.run_name = run_name
        self.model_dir = model_dir

    @modal.enter()
    def init(self):
        if self.run_name:
            run_name = self.run_name
        else:
            # Pick the last run automatically
            run_name = VOLUME_CONFIG[self.model_dir].listdir("/")[-1].path

        # Grab the output dir (usually "lora-out")
        with open(f"/runs/{run_name}/config.yml") as f:
            output_dir = yaml.safe_load(f.read())["output_dir"]

        model_path = f"/runs/{run_name}/{output_dir}/merged"
        print("Initializing vLLM engine on:", model_path)

        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -34,8 +47,7 @@ def __init__(self, run_folder: str) -> None:
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @modal.method()
    async def completion(self, input: str):
    async def _stream(self, input: str):
        if not input:
            return

@@ -71,16 +83,25 @@ async def completion(self, input: str):
print(f"Request completed: {throughput:.4f} tokens/s")
print(request_output.outputs[0].text)

    @modal.method()
    async def completion(self, input: str):
        async for text in self._stream(input):
            yield text

    @modal.web_endpoint()
    async def web(self, input: str):
        return StreamingResponse(self._stream(input), media_type="text/event-stream")


@stub.local_entrypoint()
def inference_main(run_folder: str, prompt: str = ""):
def inference_main(run_name: str = "", prompt: str = ""):
    if prompt:
        for chunk in Inference(run_folder).completion.remote_gen(prompt):
        for chunk in Inference(run_name).completion.remote_gen(prompt):
            print(chunk, end="")
    else:
        prompt = input(
            "Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
        )
        print("Loading model ...")
        for chunk in Inference(run_folder).completion.remote_gen(prompt):
        for chunk in Inference(run_name).completion.remote_gen(prompt):
            print(chunk, end="")
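
A minimal sketch of how a client might consume the new streaming web endpoint once the app is served (for example with `modal serve src.inference`). The URL below is a placeholder, since Modal assigns the actual *.modal.run address at serve/deploy time; the prompt is passed as the `input` query parameter and chunks are streamed back as they are generated.

import requests

# Placeholder: replace with the *.modal.run URL Modal prints when serving or deploying the endpoint.
url = "https://<your-workspace>--<app-name>-inference-web.modal.run"

# The endpoint streams generated text back chunk by chunk (text/event-stream),
# so read the response incrementally instead of waiting for the full body.
with requests.get(url, params={"input": "[INST] Say hello [/INST]"}, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")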
2 changes: 1 addition & 1 deletion src/train.py
@@ -150,4 +150,4 @@ def main(

print(f"Training complete. Run tag: {run_name}")
print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-folder /runs/{run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`")
