Showing 2 changed files with 107 additions and 27 deletions.
@@ -0,0 +1,81 @@
# MLX example with WasmEdge WASI-NN MLX plugin

This example demonstrates how to use the WasmEdge WASI-NN MLX plugin to run an inference task with an LLM.

## Supported Models

| Family | Models |
|--------|--------|
| LLaMA 2 | llama_2_7b_chat_hf |
| LLaMA 3 | llama_3_8b |
| TinyLLaMA | tiny_llama_1.1B_chat_v1.0 |

## Install WasmEdge with the WASI-NN MLX plugin

The MLX backend relies on [MLX](https://github.com/ml-explore/mlx), which is downloaded automatically when WasmEdge is built, so you do not need to install it yourself. If you want to use a custom MLX build, install it yourself and point CMake at it with the `CMAKE_PREFIX_PATH` variable when configuring the build.

Build and install WasmEdge from source:

```bash
cd <path/to/your/wasmedge/source/folder>

cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx"
cmake --build build

# Install WasmEdge together with the WASI-NN plugin.
cmake --install build
```
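
If you want to build against your own MLX installation instead of the auto-downloaded one, pass its location through `CMAKE_PREFIX_PATH`. A minimal sketch, assuming MLX is installed under `/opt/mlx` (adjust the prefix to your setup):

```bash
# Point CMake at a pre-installed MLX instead of letting the build download it.
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release \
  -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx" \
  -DCMAKE_PREFIX_PATH=/opt/mlx
cmake --build build
cmake --install build
```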

After installation you will have the `wasmedge` runtime under `/usr/local/bin` and the WASI-NN plugin with the MLX backend under `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so`.
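
If you install to a non-default prefix, the runtime may not find the plugin automatically. A sketch using the `WASMEDGE_PLUGIN_PATH` environment variable, assuming an install prefix of `$HOME/.wasmedge` (adjust to wherever you installed):

```bash
# Tell wasmedge where to look for the WASI-NN plugin.
export WASMEDGE_PLUGIN_PATH=$HOME/.wasmedge/lib/wasmedge
```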

## Download the model and tokenizer

This example uses `tiny_llama_1.1B_chat_v1.0`; you can switch to `llama_2_7b_chat_hf` or `llama_3_8b` instead.

```bash
# Download the model weights
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors
# Download the tokenizer
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json
```
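
The larger models ship their weights as several safetensors shards plus a `tokenizer.json`. An illustrative sketch for `llama_2_7b_chat_hf`, matching the `llama2-7b/` layout used in the multi-file example later in this README (the exact repository and shard names depend on the model you choose, and the official Meta repositories may require you to accept a license and authenticate with Hugging Face first):

```bash
mkdir -p llama2-7b
wget -P llama2-7b https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/model-00001-of-00002.safetensors
wget -P llama2-7b https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/model-00002-of-00002.safetensors
wget https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.json
```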

## Build wasm

Run the following command to build the WASM module; the output file will be at `target/wasm32-wasi/release/`:

```bash
cargo build --target wasm32-wasi --release
```
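
If the build fails because the target is missing, add it first with rustup (a standard Rust setup step, not specific to this example):

```bash
rustup target add wasm32-wasi
```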

## Execute

Execute the WASM module with `wasmedge`, using `--nn-preload` to load the model. The preload string has the form `<alias>:mlx:AUTO:<weight file>`; the alias (`default` here) is also passed to the WASM program as its command-line argument so it can load the preloaded graph by name.

```bash
wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:model.safetensors \
  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```

If your model has multiple weight files, you need to list all of them in `--nn-preload`.

For example:

```bash
wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:llama2-7b/model-00001-of-00002.safetensors:llama2-7b/model-00002-of-00002.safetensors \
  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```

## Other

There are some metadata options you can set for the MLX plugin:

- `model_type` (required): the LLM model type.
- `tokenizer` (required): path to `tokenizer.json`.
- `max_token` (optional): maximum number of tokens to generate; default is 1024.
- `enable_debug_log` (optional): whether to print debug logs; default is false.

They are passed as a JSON string in the graph builder configuration, for example:

```rust
let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
    .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer": tokenizer_path, "max_token": 100})).expect("Failed to serialize options"))
    .build_from_cache(model_name)
    .expect("Failed to build graph");
```
@@ -1,42 +1,41 @@
```rust
use tokenizers::tokenizer::Tokenizer;
use serde_json::json;
use std::env;
use wasmedge_wasi_nn::{
    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Preserve for 4096 tokens with average token length 6
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn main() {
    let tokenizer_path = "./tokenizer.json";
    let prompt = "Once upon a time, there existed a little girl,";
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];

    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer": tokenizer_path, "max_token": 100})).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");
    let tensor_data = prompt.as_bytes().to_vec();
    context
        .set_input(0, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    context.compute().expect("Failed to compute");
    let output = get_output_from_context(&context);

    println!("{}", output.trim());
}
```