forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Llama2.c wasm module. (huggingface#686)
- Loading branch information
1 parent
9bd486f
commit 8e84d8a
Showing
3 changed files
with
90 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use candle::{Device, Tensor}; | ||
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData}; | ||
use wasm_bindgen::prelude::*; | ||
|
||
/// Wasm-exposed wrapper around the llama2.c model: holds the loaded
/// model/tokenizer (`inner`), the sampling state (`logits_processor`),
/// and the token ids generated so far (`tokens`).
#[wasm_bindgen]
pub struct Model {
    // Loaded model + tokenizer (see `worker::Model`).
    inner: M,
    // Sampling strategy; replaced on each `init_with_prompt` call.
    logits_processor: LogitsProcessor,
    // All token ids produced so far; `next_token` feeds the last one back in.
    tokens: Vec<u32>,
}
|
||
impl Model { | ||
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> { | ||
let dev = Device::Cpu; | ||
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?; | ||
let logits = self.inner.llama.forward(&input, tokens.len())?; | ||
let logits = logits.squeeze(0)?; | ||
|
||
let next_token = self.logits_processor.sample(&logits)?; | ||
self.tokens.push(next_token); | ||
let text = match self.inner.tokenizer.id_to_token(next_token) { | ||
Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"), | ||
None => "".to_string(), | ||
}; | ||
Ok(text) | ||
} | ||
} | ||
|
||
#[wasm_bindgen] | ||
impl Model { | ||
#[wasm_bindgen(constructor)] | ||
pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> { | ||
let model = M::load(ModelData { | ||
tokenizer, | ||
model: weights, | ||
}); | ||
let logits_processor = LogitsProcessor::new(299792458, None); | ||
match model { | ||
Ok(inner) => Ok(Self { | ||
inner, | ||
logits_processor, | ||
tokens: vec![], | ||
}), | ||
Err(e) => Err(JsError::new(&e.to_string())), | ||
} | ||
} | ||
|
||
#[wasm_bindgen] | ||
pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> { | ||
// First reset the cache. | ||
{ | ||
let mut cache = self.inner.cache.kvs.lock().unwrap(); | ||
for elem in cache.iter_mut() { | ||
*elem = None | ||
} | ||
} | ||
let temp = if temp <= 0. { None } else { Some(temp) }; | ||
self.logits_processor = LogitsProcessor::new(299792458, temp); | ||
self.tokens.clear(); | ||
let tokens = self | ||
.inner | ||
.tokenizer | ||
.encode(prompt.to_string(), true) | ||
.map_err(|m| JsError::new(&m.to_string()))? | ||
.get_ids() | ||
.to_vec(); | ||
let text = self | ||
.process(&tokens) | ||
.map_err(|m| JsError::new(&m.to_string()))?; | ||
Ok(text) | ||
} | ||
|
||
#[wasm_bindgen] | ||
pub fn next_token(&mut self) -> Result<String, JsError> { | ||
let last_token = *self.tokens.last().unwrap(); | ||
let text = self | ||
.process(&[last_token]) | ||
.map_err(|m| JsError::new(&m.to_string()))?; | ||
Ok(text) | ||
} | ||
} | ||
|
||
// Empty entry point: this is a `bin` target only so the crate builds; all
// functionality is exposed to JS through the `#[wasm_bindgen]` bindings.
fn main() {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
mod app; | ||
mod model; | ||
mod worker; | ||
pub mod model; | ||
pub mod worker; | ||
pub use app::App; | ||
pub use worker::Worker; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters