diff --git a/README.md b/README.md
index 98789cc..c0a9e31 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ modal run --detach src.train --config=config/mistral.yml --data=data/sqlqa.jsonl
 4. Try the model from a completed training run. You can select a folder via `modal volume ls example-runs-vol`, and then specify the training folder with the `--run-folder` flag (something like `/runs/axo-2023-11-24-17-26-66e8`) for inference:
 
 ```bash
-modal run -q src.inference --run-folder /runs/
+modal run -q src.inference --run-name
 ```
 
 Our quickstart example trains a 7B model on a text-to-SQL dataset as a proof of concept (it takes just a few minutes). It uses DeepSpeed ZeRO-3 to shard the model state across 2 A100s. Inference on the fine-tuned model displays conformity to the output structure (`[SQL] ... [/SQL]`). To achieve better results, you would need to use more data! Refer to the full development section below.
diff --git a/src/inference.py b/src/inference.py
index ddc88c6..2adc736 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -1,7 +1,9 @@
 import time
 from pathlib import Path
 
+import yaml
 import modal
+from fastapi.responses import StreamingResponse
 
 from .common import stub, vllm_image, VOLUME_CONFIG
 
@@ -16,14 +18,25 @@
     container_idle_timeout=120,
 )
 class Inference:
-    def __init__(self, run_folder: str) -> None:
-        import yaml
-
-        with open(f"{run_folder}/config.yml") as f:
-            config = yaml.safe_load(f.read())
-        model_path = (Path(run_folder) / config["output_dir"] / "merged").resolve()
-
+    def __init__(self, run_name: str = "", model_dir: str = "/runs") -> None:
+        self.run_name = run_name
+        self.model_dir = model_dir
+
+    @modal.enter()
+    def init(self):
+        if self.run_name:
+            run_name = self.run_name
+        else:
+            # Pick the last run automatically
+            run_name = VOLUME_CONFIG[self.model_dir].listdir("/")[-1].path
+
+        # Grab the output dir (usually "lora-out")
+        with open(f"{self.model_dir}/{run_name}/config.yml") as f:
+            output_dir = yaml.safe_load(f.read())["output_dir"]
+
+        model_path = f"{self.model_dir}/{run_name}/{output_dir}/merged"
         print("Initializing vLLM engine on:", model_path)
+
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.engine.async_llm_engine import AsyncLLMEngine
 
@@ -34,8 +47,7 @@ def __init__(self, run_folder: str) -> None:
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
-    @modal.method()
-    async def completion(self, input: str):
+    async def _stream(self, input: str):
         if not input:
             return
 
@@ -71,16 +83,25 @@ async def completion(self, input: str):
         print(f"Request completed: {throughput:.4f} tokens/s")
         print(request_output.outputs[0].text)
 
+    @modal.method()
+    async def completion(self, input: str):
+        async for text in self._stream(input):
+            yield text
+
+    @modal.web_endpoint()
+    async def web(self, input: str):
+        return StreamingResponse(self._stream(input), media_type="text/event-stream")
+
 
 @stub.local_entrypoint()
-def inference_main(run_folder: str, prompt: str = ""):
+def inference_main(run_name: str = "", prompt: str = ""):
     if prompt:
-        for chunk in Inference(run_folder).completion.remote_gen(prompt):
+        for chunk in Inference(run_name).completion.remote_gen(prompt):
             print(chunk, end="")
     else:
         prompt = input(
             "Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
         )
         print("Loading model ...")
-        for chunk in Inference(run_folder).completion.remote_gen(prompt):
+        for chunk in Inference(run_name).completion.remote_gen(prompt):
             print(chunk, end="")
diff --git a/src/train.py b/src/train.py
index 5307f57..1c7df7b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -150,4 +150,4 @@ def main(
 
     print(f"Training complete. Run tag: {run_name}")
     print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`")
-    print(f"To run sample inference, run `modal run -q src.inference --run-folder /runs/{run_name}`")
+    print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`")
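For reference, the new `@modal.web_endpoint()` method streams generated text over HTTP once the app is deployed (e.g. with `modal deploy src.inference`). The snippet below is a minimal client sketch, not part of this change: it assumes the `requests` library is available and uses a placeholder URL, since Modal prints the actual endpoint URL at deploy time.

```python
import requests

# Placeholder URL -- Modal prints the real `web` endpoint URL when the app is deployed.
ENDPOINT_URL = "https://<your-workspace>--<app-name>-web.modal.run"


def stream_completion(prompt: str) -> None:
    # The endpoint takes the prompt as the `input` query parameter and streams
    # back text chunks as they are generated; print them as they arrive.
    with requests.get(ENDPOINT_URL, params={"input": prompt}, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)


if __name__ == "__main__":
    stream_completion("[INST] How many users signed up last month? [/INST]")
```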