Add web endpoint to the inference server
erikbern committed Mar 12, 2024
1 parent bcc4e63 commit 29ca3b7
Showing 3 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ modal run --detach src.train --config=config/mistral.yml --data=data/sqlqa.jsonl
4. Try the model from a completed training run. You can select a folder via `modal volume ls example-runs-vol`, and then specify the training folder with the `--run-folder` flag (something like `/runs/axo-2023-11-24-17-26-66e8`) for inference:

```bash
-modal run -q src.inference --run-folder /runs/<run_tag>
+modal run -q src.inference --run-name <run_tag>
```

Our quickstart example trains a 7B model on a text-to-SQL dataset as a proof of concept (it takes just a few minutes). It uses DeepSpeed ZeRO-3 to shard the model state across 2 A100s. Inference on the fine-tuned model displays conformity to the output structure (`[SQL] ... [/SQL]`). To achieve better results, you would need to use more data! Refer to the full development section below.
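For illustration, the end-to-end flow with the renamed flag might look like the following (a sketch using the sample run tag from the step above; your tag will differ):

```bash
# Pick a completed run from the shared volume, then pass its tag directly
modal volume ls example-runs-vol
modal run -q src.inference --run-name axo-2023-11-24-17-26-66e8
```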
45 changes: 33 additions & 12 deletions src/inference.py
@@ -1,7 +1,9 @@
import time
from pathlib import Path
+import yaml

import modal
+from fastapi.responses import StreamingResponse

from .common import stub, vllm_image, VOLUME_CONFIG

@@ -16,14 +18,25 @@
    container_idle_timeout=120,
)
class Inference:
-    def __init__(self, run_folder: str) -> None:
-        import yaml
-
-        with open(f"{run_folder}/config.yml") as f:
-            config = yaml.safe_load(f.read())
-        model_path = (Path(run_folder) / config["output_dir"] / "merged").resolve()
-
+    def __init__(self, run_name: str = "", model_dir: str = "/runs") -> None:
+        self.run_name = run_name
+        self.model_dir = model_dir
+
+    @modal.enter()
+    def init(self):
+        if self.run_name:
+            run_name = self.run_name
+        else:
+            # Pick the last run automatically
+            run_name = VOLUME_CONFIG[self.model_dir].listdir("/")[-1].path
+
+        # Grab the output dir (usually "lora-out")
+        with open(f"{self.model_dir}/{run_name}/config.yml") as f:
+            output_dir = yaml.safe_load(f.read())["output_dir"]
+
+        model_path = f"{self.model_dir}/{run_name}/{output_dir}/merged"
        print("Initializing vLLM engine on:", model_path)

        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -34,8 +47,7 @@ def __init__(self, run_folder: str) -> None:
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

-    @modal.method()
-    async def completion(self, input: str):
+    async def _stream(self, input: str):
        if not input:
            return

@@ -71,16 +83,25 @@ async def completion(self, input: str):
print(f"Request completed: {throughput:.4f} tokens/s")
print(request_output.outputs[0].text)

+    @modal.method()
+    async def completion(self, input: str):
+        async for text in self._stream(input):
+            yield text
+
+    @modal.web_endpoint()
+    async def web(self, input: str):
+        return StreamingResponse(self._stream(input), media_type="text/event-stream")


@stub.local_entrypoint()
-def inference_main(run_folder: str, prompt: str = ""):
+def inference_main(run_name: str = "", prompt: str = ""):
    if prompt:
-        for chunk in Inference(run_folder).completion.remote_gen(prompt):
+        for chunk in Inference(run_name).completion.remote_gen(prompt):
            print(chunk, end="")
    else:
        prompt = input(
            "Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
        )
        print("Loading model ...")
-        for chunk in Inference(run_folder).completion.remote_gen(prompt):
+        for chunk in Inference(run_name).completion.remote_gen(prompt):
            print(chunk, end="")
2 changes: 1 addition & 1 deletion src/train.py
@@ -150,4 +150,4 @@ def main(

print(f"Training complete. Run tag: {run_name}")
print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-folder /runs/{run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`")
