Skip to content

Commit

Permalink
Llama2.c wasm module. (huggingface#686)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Aug 31, 2023
1 parent 9bd486f commit 8e84d8a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
83 changes: 83 additions & 0 deletions candle-wasm-examples/llama2-c/src/bin/m.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use candle::{Device, Tensor};
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
// Wasm-exposed wrapper bundling the llama2.c model with its generation state.
pub struct Model {
// Underlying model from the worker module (holds the llama weights,
// tokenizer, and KV cache — see `worker::Model`).
inner: M,
// Sampler turning logits into the next token id (seeded in `new`).
logits_processor: LogitsProcessor,
// All token ids generated so far; `next_token` feeds the last one back in.
tokens: Vec<u32>,
}

impl Model {
    /// Run one forward pass over `tokens`, sample the next token, record it
    /// in `self.tokens`, and return its decoded text (empty string when the
    /// id has no entry in the tokenizer vocabulary).
    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
        let device = Device::Cpu;
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = self.inner.llama.forward(&input, tokens.len())?.squeeze(0)?;

        let next_token = self.logits_processor.sample(&logits)?;
        self.tokens.push(next_token);
        // Decode the piece: '▁' is the SentencePiece word-separator glyph and
        // "<0x0A>" is the byte-fallback encoding of a newline.
        let text = self
            .inner
            .tokenizer
            .id_to_token(next_token)
            .map(|piece| piece.replace('▁', " ").replace("<0x0A>", "\n"))
            .unwrap_or_default();
        Ok(text)
    }
}

#[wasm_bindgen]
impl Model {
    /// Build a model from raw llama2.c `weights` and a serialized `tokenizer`.
    ///
    /// Load failures are surfaced to JavaScript as a `JsError`.
    #[wasm_bindgen(constructor)]
    pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
        let inner = M::load(ModelData {
            tokenizer,
            model: weights,
        })
        .map_err(|e| JsError::new(&e.to_string()))?;
        // Fixed seed so generation is reproducible across page reloads.
        let logits_processor = LogitsProcessor::new(299792458, None);
        Ok(Self {
            inner,
            logits_processor,
            tokens: vec![],
        })
    }

    /// Reset generation state, encode `prompt`, and return the first sampled
    /// piece of text.
    ///
    /// `temp <= 0.` selects greedy (argmax) sampling; a positive value enables
    /// temperature sampling.
    #[wasm_bindgen]
    pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
        // Clear the KV cache so the new prompt does not attend to stale
        // entries from a previous generation. Scoped so the mutex guard is
        // dropped before the forward pass below.
        {
            let mut cache = self.inner.cache.kvs.lock().unwrap();
            for elem in cache.iter_mut() {
                *elem = None
            }
        }
        let temp = if temp <= 0. { None } else { Some(temp) };
        self.logits_processor = LogitsProcessor::new(299792458, temp);
        self.tokens.clear();
        // `prompt` is already owned, so pass it through directly instead of
        // the redundant `prompt.to_string()` copy.
        let tokens = self
            .inner
            .tokenizer
            .encode(prompt, true)
            .map_err(|m| JsError::new(&m.to_string()))?
            .get_ids()
            .to_vec();
        let text = self
            .process(&tokens)
            .map_err(|m| JsError::new(&m.to_string()))?;
        Ok(text)
    }

    /// Sample and return the next piece of text, feeding back the last
    /// generated token.
    ///
    /// Returns a `JsError` (instead of panicking, as the previous `unwrap`
    /// did) when called before `init_with_prompt` has produced any token.
    #[wasm_bindgen]
    pub fn next_token(&mut self) -> Result<String, JsError> {
        let last_token = *self
            .tokens
            .last()
            .ok_or_else(|| JsError::new("no tokens yet: call init_with_prompt first"))?;
        let text = self
            .process(&[last_token])
            .map_err(|m| JsError::new(&m.to_string()))?;
        Ok(text)
    }
}

fn main() {}
4 changes: 2 additions & 2 deletions candle-wasm-examples/llama2-c/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod app;
mod model;
mod worker;
pub mod model;
pub mod worker;
pub use app::App;
pub use worker::Worker;
10 changes: 5 additions & 5 deletions candle-wasm-examples/llama2-c/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
Ok(tensor)
}

struct Model {
cache: Cache,
pub struct Model {
pub cache: Cache,
config: Config,
llama: Llama,
tokenizer: Tokenizer,
pub llama: Llama,
pub tokenizer: Tokenizer,
}

pub struct LogitsProcessor {
Expand Down Expand Up @@ -275,7 +275,7 @@ impl TransformerWeights {
}

impl Model {
fn load(md: ModelData) -> Result<Self> {
pub fn load(md: ModelData) -> Result<Self> {
let dev = Device::Cpu;
let mut model = std::io::Cursor::new(md.model);
let config = Config::from_reader(&mut model)?;
Expand Down

0 comments on commit 8e84d8a

Please sign in to comment.