[Example] MLX: add documentation
grorge123 authored and hydai committed Sep 24, 2024
1 parent 7a323a7 commit 7ee0774
Showing 2 changed files with 107 additions and 27 deletions.
81 changes: 81 additions & 0 deletions wasmedge-mlx/README.md
@@ -0,0 +1,81 @@
# MLX example with WasmEdge WASI-NN MLX plugin

This example demonstrates how to use the WasmEdge WASI-NN MLX plugin to perform an inference task with an LLM.

## Supported Models

| Family | Models |
|--------|--------|
| LLaMA 2 | llama_2_7b_chat_hf |
| LLaMA 3 | llama_3_8b |
| TinyLLaMA | tiny_llama_1.1B_chat_v1.0 |

## Install WasmEdge with WASI-NN MLX plugin

The MLX backend relies on [MLX](https://github.com/ml-explore/mlx), which is downloaded automatically when you build WasmEdge, so you do not need to install it yourself. If you want to use a custom MLX build, install it yourself and set the `CMAKE_PREFIX_PATH` variable when configuring CMake (see the sketch after the build commands below).

Build and install WasmEdge from source:

``` bash
cd <path/to/your/wasmedge/source/folder>

cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx"
cmake --build build

# Install the build so that the WASI-NN plugin can be found by the runtime.
cmake --install build
```
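
If you want to use your own MLX build instead of the auto-downloaded one, here is a minimal sketch of the configure step, assuming MLX is installed under a hypothetical `/opt/mlx` prefix:

``` bash
# Assumption: MLX is already installed under /opt/mlx; adjust the path to your installation.
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release \
  -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx" \
  -DCMAKE_PREFIX_PATH=/opt/mlx
cmake --build build
cmake --install build
```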

After installation, you will have the `wasmedge` runtime executable under `/usr/local/bin` and the WASI-NN plugin with the MLX backend at `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so`.
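
If you install WasmEdge into a non-default prefix instead, here is a minimal sketch, assuming a hypothetical `$HOME/.wasmedge` prefix, that puts the runtime on your `PATH` and points `WASMEDGE_PLUGIN_PATH` at the plugin directory:

``` bash
# Assumption: WasmEdge was configured with -DCMAKE_INSTALL_PREFIX=$HOME/.wasmedge.
export PATH="$HOME/.wasmedge/bin:$PATH"
export WASMEDGE_PLUGIN_PATH="$HOME/.wasmedge/lib/wasmedge"
wasmedge --version  # verify the runtime is found
```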

## Download the model and tokenizer

In this example, we will use `tiny_llama_1.1B_chat_v1.0`, which you can change to `llama_2_7b_chat_hf` or `llama_3_8b`.

``` bash
# Download the model weights
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors
# Download the tokenizer
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json
```
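
If you switch to `llama_2_7b_chat_hf`, the weights are split across multiple `.safetensors` files (see the multi-file `--nn-preload` example below). Here is a minimal download sketch, assuming you have accepted the Llama 2 license on Hugging Face and exported an access token as `HF_TOKEN`, since the repository is gated:

``` bash
# Assumption: you have access to meta-llama/Llama-2-7b-chat-hf and a token in $HF_TOKEN.
mkdir -p llama2-7b
for f in model-00001-of-00002.safetensors model-00002-of-00002.safetensors tokenizer.json; do
  wget --header="Authorization: Bearer $HF_TOKEN" -O "llama2-7b/$f" \
    "https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/$f"
done
```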

## Build wasm

Run the following command to build the Wasm module; the output file will be at `target/wasm32-wasi/release/`:

```bash
cargo build --target wasm32-wasi --release
```
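
If the build fails because the `wasm32-wasi` target is not installed, add it with rustup and re-run the build:

``` bash
rustup target add wasm32-wasi
```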

## Execute

Execute the Wasm module with `wasmedge`, using `--nn-preload` to load the model.

``` bash
wasmedge --dir .:. \
--nn-preload default:mlx:AUTO:model.safetensors \
./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```

If your model has multiple weight files, you need to list all of them in `--nn-preload`.

For example:
``` bash
wasmedge --dir .:. \
--nn-preload default:mlx:AUTO:llama2-7b/model-00001-of-00002.safetensors:llama2-7b/model-00002-of-00002.safetensors \
./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```

## Other

You can set the following metadata options for the MLX plugin:

- `model_type` (required): the LLM model type.
- `tokenizer` (required): path to the `tokenizer.json` file.
- `max_token` (optional): maximum number of tokens to generate; defaults to 1024.
- `enable_debug_log` (optional): whether to print debug logs; defaults to `false`.

``` rust
let config = json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer": tokenizer_path, "max_token": 100});
let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
    .config(serde_json::to_string(&config).expect("Failed to serialize options"))
    .build_from_cache(model_name)
    .expect("Failed to build graph");
```
53 changes: 26 additions & 27 deletions wasmedge-mlx/src/main.rs
@@ -1,42 +1,41 @@
use serde_json::json;
use std::env;
use wasmedge_wasi_nn::{
    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve space for 4096 tokens with an average token length of 6 bytes.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn main() {
    let tokenizer_path = "./tokenizer.json";
    let prompt = "Once upon a time, there existed a little girl,";
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];
    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer": tokenizer_path, "max_token": 100})).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");
    let tensor_data = prompt.as_bytes().to_vec();
    context
        .set_input(0, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    context.compute().expect("Failed to compute");
    let output = get_output_from_context(&context);

    println!("{}", output.trim());
}
