From 8e84d8a59beeaa0ab051ac0d8febf1b01a234f75 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 31 Aug 2023 08:44:32 +0200
Subject: [PATCH] Llama2.c wasm module. (#686)

---
 candle-wasm-examples/llama2-c/src/bin/m.rs  | 83 +++++++++++++++++++++
 candle-wasm-examples/llama2-c/src/lib.rs    |  4 +-
 candle-wasm-examples/llama2-c/src/worker.rs | 10 +--
 3 files changed, 90 insertions(+), 7 deletions(-)
 create mode 100644 candle-wasm-examples/llama2-c/src/bin/m.rs

diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
new file mode 100644
index 0000000000..ba9ed58d19
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -0,0 +1,83 @@
+use candle::{Device, Tensor};
+use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Model {
+    inner: M,
+    logits_processor: LogitsProcessor,
+    tokens: Vec<u32>,
+}
+
+impl Model {
+    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
+        let dev = Device::Cpu;
+        let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
+        let logits = self.inner.llama.forward(&input, tokens.len())?;
+        let logits = logits.squeeze(0)?;
+
+        let next_token = self.logits_processor.sample(&logits)?;
+        self.tokens.push(next_token);
+        let text = match self.inner.tokenizer.id_to_token(next_token) {
+            Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"),
+            None => "".to_string(),
+        };
+        Ok(text)
+    }
+}
+
+#[wasm_bindgen]
+impl Model {
+    #[wasm_bindgen(constructor)]
+    pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
+        let model = M::load(ModelData {
+            tokenizer,
+            model: weights,
+        });
+        let logits_processor = LogitsProcessor::new(299792458, None);
+        match model {
+            Ok(inner) => Ok(Self {
+                inner,
+                logits_processor,
+                tokens: vec![],
+            }),
+            Err(e) => Err(JsError::new(&e.to_string())),
+        }
+    }
+
+    #[wasm_bindgen]
+    pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
+        // First reset the cache.
+        {
+            let mut cache = self.inner.cache.kvs.lock().unwrap();
+            for elem in cache.iter_mut() {
+                *elem = None
+            }
+        }
+        let temp = if temp <= 0. { None } else { Some(temp) };
+        self.logits_processor = LogitsProcessor::new(299792458, temp);
+        self.tokens.clear();
+        let tokens = self
+            .inner
+            .tokenizer
+            .encode(prompt.to_string(), true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+        let text = self
+            .process(&tokens)
+            .map_err(|m| JsError::new(&m.to_string()))?;
+        Ok(text)
+    }
+
+    #[wasm_bindgen]
+    pub fn next_token(&mut self) -> Result<String, JsError> {
+        let last_token = *self.tokens.last().unwrap();
+        let text = self
+            .process(&[last_token])
+            .map_err(|m| JsError::new(&m.to_string()))?;
+        Ok(text)
+    }
+}
+
+fn main() {}
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
index b6b4004f30..cd7834b52e 100644
--- a/candle-wasm-examples/llama2-c/src/lib.rs
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -1,5 +1,5 @@
 mod app;
-mod model;
-mod worker;
+pub mod model;
+pub mod worker;
 pub use app::App;
 pub use worker::Worker;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 0ee199afbc..e15aaa79a7 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -49,11 +49,11 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
     Ok(tensor)
 }
 
-struct Model {
-    cache: Cache,
+pub struct Model {
+    pub cache: Cache,
     config: Config,
-    llama: Llama,
-    tokenizer: Tokenizer,
+    pub llama: Llama,
+    pub tokenizer: Tokenizer,
 }
 
 pub struct LogitsProcessor {
@@ -275,7 +275,7 @@ impl TransformerWeights {
 }
 
 impl Model {
-    fn load(md: ModelData) -> Result<Self> {
+    pub fn load(md: ModelData) -> Result<Self> {
         let dev = Device::Cpu;
         let mut model = std::io::Cursor::new(md.model);
         let config = Config::from_reader(&mut model)?;
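
Usage sketch (not part of the patch): the call sequence the new `src/bin/m.rs` bindings are designed for, written as plain Rust for illustration. This is a hypothetical driver, assuming a wasm32 build with the code living next to `Model` in `m.rs`; the `generate` helper name, the prompt, the 0.8 temperature, and the 200-token budget are all placeholders, not part of the commit.

    // Hypothetical helper; `weights` holds a llama2.c checkpoint and
    // `tokenizer` a serialized tokenizer, both as raw bytes.
    fn generate(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<String, JsError> {
        // The constructor parses both byte buffers and builds the model once.
        let mut model = Model::new(weights, tokenizer)?;
        // Feed the whole prompt in one pass; this clears the KV cache,
        // reseeds the sampler, and returns the first sampled piece of text.
        let mut story = model.init_with_prompt("Once upon a time".to_string(), 0.8)?;
        // Each subsequent step feeds back only the last sampled token,
        // since the KV cache already holds the earlier context.
        for _ in 0..200 {
            story += &model.next_token()?;
        }
        Ok(story)
    }

The same three calls map one-to-one onto the JavaScript side once the crate is built with wasm-bindgen: `new Model(weights, tokenizer)`, then `init_with_prompt`, then `next_token` in a loop. Note that the sampler seed is fixed (299792458), so for a given prompt and temperature the output is deterministic until `init_with_prompt` installs a fresh `LogitsProcessor`.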