Merge pull request #38 from modal-labs/erikbern/inference-web-endpoint
Add web endpoint to the inference server
erikbern authored Mar 12, 2024
2 parents bcc4e63 + 3cea777 commit 0772e84
Showing 4 changed files with 41 additions and 20 deletions.
README.md (1 addition & 1 deletion)
@@ -45,7 +45,7 @@ modal run --detach src.train --config=config/mistral.yml --data=data/sqlqa.jsonl
4. Try the model from a completed training run. You can select a folder via `modal volume ls example-runs-vol`, and then specify the training folder with the `--run-folder` flag (something like `/runs/axo-2023-11-24-17-26-66e8`) for inference:

```bash
modal run -q src.inference --run-folder /runs/<run_tag>
modal run -q src.inference --run-name <run_tag>
```

Our quickstart example trains a 7B model on a text-to-SQL dataset as a proof of concept (it takes just a few minutes). It uses DeepSpeed ZeRO-3 to shard the model state across 2 A100s. Inference on the fine-tuned model displays conformity to the output structure (`[SQL] ... [/SQL]`). To achieve better results, you would need to use more data! Refer to the full development section below.
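
For reference, the two-step workflow described in the README would look roughly like this with the new flag (the run tag is the README's illustrative example; an actual tag comes from listing the volume):

```bash
# List completed training runs stored on the shared volume
modal volume ls example-runs-vol

# Run inference against a specific run (illustrative tag from the README)
modal run -q src.inference --run-name axo-2023-11-24-17-26-66e8
```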
ci/check_inference.py (1 addition & 1 deletion)
@@ -10,7 +10,7 @@
CREATE TABLE head (age INTEGER)
How many heads of the departments are older than 56 ? [/INST] """

p = subprocess.Popen(["modal", "run", "src.inference", "--run-folder", f"/runs/{run_name}", "--prompt", prompt], stdout=subprocess.PIPE)
p = subprocess.Popen(["modal", "run", "src.inference", "--run-name", run_name, "--prompt", prompt], stdout=subprocess.PIPE)
output = ""

for line in iter(p.stdout.readline, b''):
src/inference.py (38 additions & 17 deletions)
@@ -1,12 +1,20 @@
import time
from pathlib import Path
import yaml

import modal
from fastapi.responses import StreamingResponse

from .common import stub, vllm_image, VOLUME_CONFIG

N_INFERENCE_GPU = 2

with vllm_image.imports():
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


@stub.cls(
gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
@@ -16,16 +24,24 @@
container_idle_timeout=120,
)
class Inference:
def __init__(self, run_folder: str) -> None:
import yaml

with open(f"{run_folder}/config.yml") as f:
config = yaml.safe_load(f.read())
model_path = (Path(run_folder) / config["output_dir"] / "merged").resolve()

def __init__(self, run_name: str = "", run_dir: str = "/runs") -> None:
self.run_name = run_name
self.run_dir = run_dir

@modal.enter()
def init(self):
if self.run_name:
run_name = self.run_name
else:
# Pick the last run automatically
run_name = VOLUME_CONFIG[self.run_dir].listdir("/")[-1].path

# Grab the output dir (usually "lora-out")
with open(f"{self.run_dir}/{run_name}/config.yml") as f:
output_dir = yaml.safe_load(f.read())["output_dir"]

model_path = f"{self.run_dir}/{run_name}/{output_dir}/merged"
print("Initializing vLLM engine on:", model_path)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
model=model_path,
Expand All @@ -34,14 +50,10 @@ def __init__(self, run_folder: str) -> None:
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)

@modal.method()
async def completion(self, input: str):
async def _stream(self, input: str):
if not input:
return

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

sampling_params = SamplingParams(
repetition_penalty=1.1,
temperature=0.2,
@@ -71,16 +83,25 @@ async def completion(self, input: str):
print(f"Request completed: {throughput:.4f} tokens/s")
print(request_output.outputs[0].text)

@modal.method()
async def completion(self, input: str):
async for text in self._stream(input):
yield text

@modal.web_endpoint()
async def web(self, input: str):
return StreamingResponse(self._stream(input), media_type="text/event-stream")


@stub.local_entrypoint()
def inference_main(run_folder: str, prompt: str = ""):
def inference_main(run_name: str = "", prompt: str = ""):
if prompt:
for chunk in Inference(run_folder).completion.remote_gen(prompt):
for chunk in Inference(run_name).completion.remote_gen(prompt):
print(chunk, end="")
else:
prompt = input(
"Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
)
print("Loading model ...")
for chunk in Inference(run_folder).completion.remote_gen(prompt):
for chunk in Inference(run_name).completion.remote_gen(prompt):
print(chunk, end="")
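
Beyond the CLI entrypoint, the new `@modal.web_endpoint()` method exposes the same streaming generation over HTTP. A minimal sketch of calling it, assuming the app has been deployed (or is running under `modal serve`) and using a placeholder URL, since Modal prints the actual endpoint URL for your workspace:

```bash
# Placeholder URL: substitute the endpoint URL Modal prints on deploy/serve.
# -N disables output buffering so streamed tokens appear as they arrive;
# the prompt is passed as the `input` query parameter expected by web().
curl -G -N "https://<workspace>--<app>-inference-web.modal.run" \
  --data-urlencode "input=[INST] How many heads of the departments are older than 56 ? [/INST]"
```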
src/train.py (1 addition & 1 deletion)
@@ -150,4 +150,4 @@ def main(

print(f"Training complete. Run tag: {run_name}")
print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-folder /runs/{run_name}`")
print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`")
