From 7ee0774b4440986eddea68c7f57e82898c50a5b4 Mon Sep 17 00:00:00 2001
From: grorge
Date: Mon, 9 Sep 2024 13:43:55 +0800
Subject: [PATCH] [Example] MLX: add documentation

---
 wasmedge-mlx/README.md   | 81 ++++++++++++++++++++++++++++++++++++++++
 wasmedge-mlx/src/main.rs | 53 +++++++++++++-------------
 2 files changed, 107 insertions(+), 27 deletions(-)
 create mode 100644 wasmedge-mlx/README.md

diff --git a/wasmedge-mlx/README.md b/wasmedge-mlx/README.md
new file mode 100644
index 0000000..95dbfd4
--- /dev/null
+++ b/wasmedge-mlx/README.md
@@ -0,0 +1,81 @@
+# MLX example with WasmEdge WASI-NN MLX plugin
+
+This example demonstrates how to use the WasmEdge WASI-NN MLX plugin to perform an inference task with an LLM.
+
+## Supported Models
+
+| Family | Models |
+|--------|--------|
+| LLaMA 2 | llama_2_7b_chat_hf |
+| LLaMA 3 | llama_3_8b |
+| TinyLLaMA | tiny_llama_1.1B_chat_v1.0 |
+
+## Install WasmEdge with WASI-NN MLX plugin
+
+The MLX backend relies on [MLX](https://github.com/ml-explore/mlx), which is downloaded automatically when you build WasmEdge, so you do not need to install it yourself. If you want to use a custom MLX build, install it yourself and set the `CMAKE_PREFIX_PATH` variable when configuring CMake.
+
+Build and install WasmEdge from source:
+
+``` bash
+cd <path/to/your/wasmedge/source/folder>
+
+cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx"
+cmake --build build
+
+# For the WASI-NN plugin, you should install this project.
+cmake --install build
+```
+
+After installation, you will have the `wasmedge` runtime under `/usr/local/bin` and the WASI-NN plug-in with the MLX backend at `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so`.
+
+## Download the model and tokenizer
+
+In this example, we use `tiny_llama_1.1B_chat_v1.0`; you can change it to `llama_2_7b_chat_hf` or `llama_3_8b`.
+
+``` bash
+# Download the model weights
+wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors
+# Download the tokenizer
+wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json
+```
+
+## Build wasm
+
+Run the following command to build the Wasm module. The output file will be at `target/wasm32-wasi/release/`.
+
+```bash
+cargo build --target wasm32-wasi --release
+```
+## Execute
+
+Execute the Wasm module with `wasmedge`, using `--nn-preload` to load the model.
+
+``` bash
+wasmedge --dir .:. \
+  --nn-preload default:mlx:AUTO:model.safetensors \
+  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
+
+```
+
+If your model has multiple weight files, you need to provide all of them in `--nn-preload`.
+
+For example:
+``` bash
+wasmedge --dir .:. \
+  --nn-preload default:mlx:AUTO:llama2-7b/model-00001-of-00002.safetensors:llama2-7b/model-00002-of-00002.safetensors \
+  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
+```
+
+## Other
+
+You can set the following metadata options for the MLX plugin:
+
+- model_type (required): the LLM model type, e.g. `tiny_llama_1.1B_chat_v1.0`.
+- tokenizer (required): the path to `tokenizer.json`.
+- max_token (optional): the maximum number of generated tokens; default is 1024.
+- enable_debug_log (optional): whether to print debug logs; default is false.
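+
+For example, these options are passed as a JSON string through the graph builder's `config` call (see `src/main.rs`, where `tokenizer_path` and `model_name` are defined):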
+
+``` rust
+let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
+    .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer":tokenizer_path, "max_token":100})).expect("Failed to serialize options"))
+    .build_from_cache(model_name)
+    .expect("Failed to build graph");
+```
\ No newline at end of file
diff --git a/wasmedge-mlx/src/main.rs b/wasmedge-mlx/src/main.rs
index a225c92..14df217 100644
--- a/wasmedge-mlx/src/main.rs
+++ b/wasmedge-mlx/src/main.rs
@@ -1,42 +1,41 @@
-use tokenizers::tokenizer::Tokenizer;
 use serde_json::json;
+use std::env;
 use wasmedge_wasi_nn::{
-    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
-    TensorType,
+    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
 };
-use std::env;
-fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> Vec<u8> {
-    // Preserve for 4096 tokens with average token length 8
-    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 8;
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
+    // Preserve for 4096 tokens with average token length 6
+    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
     let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
-    let _ = context
+    let mut output_size = context
         .get_output(index, &mut output_buffer)
         .expect("Failed to get output");
+    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
 
-    return output_buffer;
+    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
 }
 
-fn get_output_from_context(context: &GraphExecutionContext) -> Vec<u8> {
+fn get_output_from_context(context: &GraphExecutionContext) -> String {
     get_data_from_context(context, 0)
 }
 fn main() {
-    let tokenizer_path = "tokenizer.json";
-    let prompt = "Once upon a time, there existed a little girl,";
-
-    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
-        .config(serde_json::to_string(&json!({"tokenizer":tokenizer_path})).expect("Failed to serialize options"))
+    let tokenizer_path = "./tokenizer.json";
+    let prompt = "Once upon a time, there existed a little girl,";
+    let args: Vec<String> = env::args().collect();
+    let model_name: &str = &args[1];
+    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
+        .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer":tokenizer_path, "max_token":100})).expect("Failed to serialize options"))
         .build_from_cache(model_name)
         .expect("Failed to build graph");
-    let mut context = graph
-        .init_execution_context()
-        .expect("Failed to init context");
-    let tensor_data = prompt.as_bytes().to_vec();
-    context
-        .set_input(0, TensorType::U8, &[1], &tensor_data)
-        .expect("Failed to set input");
-    context.compute().expect("Failed to compute");
-    let output_bytes = get_output_from_context(&context);
+    let mut context = graph
+        .init_execution_context()
+        .expect("Failed to init context");
+    let tensor_data = prompt.as_bytes().to_vec();
+    context
+        .set_input(0, TensorType::U8, &[1], &tensor_data)
+        .expect("Failed to set input");
+    context.compute().expect("Failed to compute");
+    let output = get_output_from_context(&context);
 
-    println!("{}", output.trim());
-
-}
\ No newline at end of file
+    println!("{}", output.trim());
+}