From 4d3c06ad7abff7f0611f8d019014450e54cccd14 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Wed, 4 Oct 2023 13:41:26 +0000 Subject: [PATCH] Add walk through --- docs/examples/llm/client.py | 6 +++ docs/examples/llm/inference_server.sh | 55 ++++++++++++++++++++++----- docs/examples/llm/vllm.rst | 52 +++++++++++++++++++++++-- 3 files changed, 101 insertions(+), 12 deletions(-) diff --git a/docs/examples/llm/client.py b/docs/examples/llm/client.py index 63cabf5f..86e3ce7e 100644 --- a/docs/examples/llm/client.py +++ b/docs/examples/llm/client.py @@ -1,5 +1,9 @@ import openai + +# +# Parse the server info from the job comment +# def parse_meta(comment): data = dict() if comment != "(null)": @@ -22,10 +26,12 @@ def get_job_comment(name="inference_server.sh"): server = parse_meta(get_job_comment()) +# Override OpenAPI API URL with out custom server openai.api_key = "EMPTY" openai.api_base = f"http://{server['host']}:{server['port']}/v1" +# profit completion = openai.Completion.create( model=server['model'], prompt=args.prompt diff --git a/docs/examples/llm/inference_server.sh b/docs/examples/llm/inference_server.sh index 985a4ee5..acf2a857 100644 --- a/docs/examples/llm/inference_server.sh +++ b/docs/examples/llm/inference_server.sh @@ -16,8 +16,45 @@ #SBATCH --ntasks-per-node=1 #SBATCH --mem=32G -MODEL="$1" -PATH="$2" +usage() { + echo "Usage: $0 [-m] [-p] + echo " -h Display this help message." + echo " -m MODEL Specify a file to process." + echo " -p PATH Specify a directory to work in." + echo " ARGUMENT Any additional argument you want to process." + exit 1 +} + +MODEL="" +PATH="" +ENV="./env" + + +while getopts ":hf:d:" opt; do + case $opt in + h) + usage + ;; + m) + MODEL="$OPTARG" + ;; + p) + PATH="$OPTARG" + ;; + e) + ENV="$OPTARG" + ;; + \?) + echo "Invalid option: -$OPTARG" >&2 + usage + ;; + :) + echo "Option -$OPTARG requires an argument." >&2 + usage + ;; + esac +done + export MILA_WEIGHTS="/network/weights/" @@ -33,19 +70,19 @@ source $CONDA_BASE/../etc/profile.d/conda.sh # # Create a new environment # -conda create --prefix ./env python=3.9 -y -conda activate ./env +if [ ! -d "$ENV" ]; then + conda create --prefix $ENV python=3.9 -y +fi +conda activate $ENV pip install vllm -# -# Save metadata for retrival -# - PORT=$(python -c "import socket; sock = socket.socket(); sock.bind(('', 0)); print(sock.getsockname()[1])") HOST="$(hostname)" NAME="$WEIGHTS/$MODEL" -echo " -> $HOST:$PORT" +# +# Save metadata for retrival +# scontrol update job $SLURM_JOB_ID comment="model=$MODEL|host=$HOST|port=$PORT|shared=y" # diff --git a/docs/examples/llm/vllm.rst b/docs/examples/llm/vllm.rst index 0715ce2c..e55c1291 100644 --- a/docs/examples/llm/vllm.rst +++ b/docs/examples/llm/vllm.rst @@ -2,12 +2,58 @@ LLM Inference ============= +Server +------ + +`vLLM `_ comes with its own server entry point that mimicks OpenAI's API. +It is very easy to setup and supports a wide range of models through Huggingfaces. + + +.. code-block:: + + # sbatch inference_server.sh -m MODEL_NAME -p WEIGHT_PATH -e CONDA_ENV_NAME_TO_USE + sbatch inference_server.sh -m Llama-2-7b-chat-hf -p /network/weights/llama.var/llama2/Llama-2-7b-chat-hf -e base + + +By default the script will launch the server on an rtx8000 for 15 minutes. +You can override the defaults by specifying arguments to sbatch. -Dependencies ------------- .. code-block:: - sbatch inference_server.sh Llama-2-7b-chat-hf /network/weights/llama.var/llama2/Llama-2-7b-chat-hf + sbatch --time=00:30:00 inference_server.sh -m Llama-2-7b-chat-hf -p /network/weights/llama.var/llama2/Llama-2-7b-chat-hf -e base + +.. note:: + + We are using job comment to store hostname, port and model names, + which enable the client to automatically pick them up on its side. + + +.. literalinclude:: inference_server.sh + :language: bash + + +Client +------ + +Becasue vLLM replicates OpenAI's API, the client side is quite straight forward. +Own OpenAI's client can be reused. + +.. warning:: + + The server takes a while to setup you might to have to wait a few minutes + before the server is ready for inference. + + You can check the job log of the server. + Look for + + +.. note:: + + We use squeue to look for the inference server job to configure the + url endpoint automatically. + Make sure your job name is unique! +.. literalinclude:: client.py + :language: python