diff --git a/screenpipe-core/examples/llama.rs b/screenpipe-core/examples/llama.rs
new file mode 100644
index 00000000..9d060465
--- /dev/null
+++ b/screenpipe-core/examples/llama.rs
@@ -0,0 +1,11 @@
+use anyhow::Result;
+use screenpipe_core::llama::LlamaInitConfig;
+use screenpipe_core::llama_stream_text;
+
+fn main() -> Result<()> {
+    llama_stream_text(LlamaInitConfig::default(), |text| {
+        println!("{}", text);
+        Ok(())
+    })?;
+    Ok(())
+}
diff --git a/screenpipe-core/src/google.rs b/screenpipe-core/src/google.rs
index e44f6516..078d074b 100644
--- a/screenpipe-core/src/google.rs
+++ b/screenpipe-core/src/google.rs
@@ -4,95 +4,13 @@ mod google_module {
     use candle::{DType, Device, Tensor};
     use candle_nn::VarBuilder;
     use candle_transformers::generation::{LogitsProcessor, Sampling};
-    use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
-    use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
+    use candle_transformers::models::gemma::Model as Model1;
+    use candle_transformers::models::gemma2::Model as Model2;
     use hf_hub::api::sync::ApiBuilder;
     use hf_hub::{Repo, RepoType};
     use tokenizers::Tokenizer;

-    use crate::hub_load_safetensors;
-
-    pub struct TokenOutputStream {
-        tokenizer: tokenizers::Tokenizer,
-        tokens: Vec<u32>,
-        prev_index: usize,
-        current_index: usize,
-    }
-
-    impl TokenOutputStream {
-        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
-            Self {
-                tokenizer,
-                tokens: Vec::new(),
-                prev_index: 0,
-                current_index: 0,
-            }
-        }
-
-        pub fn into_inner(self) -> tokenizers::Tokenizer {
-            self.tokenizer
-        }
-
-        fn decode(&self, tokens: &[u32]) -> Result<String> {
-            match self.tokenizer.decode(tokens, true) {
-                Ok(str) => Ok(str),
-                Err(err) => anyhow::bail!("cannot decode: {err}"),
-            }
-        }
-
-        pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
-            let prev_text = if self.tokens.is_empty() {
-                String::new()
-            } else {
-                let tokens = &self.tokens[self.prev_index..self.current_index];
-                self.decode(tokens)?
-            };
-            self.tokens.push(token);
-            let text = self.decode(&self.tokens[self.prev_index..])?;
-            if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
-                let text = text.split_at(prev_text.len());
-                self.prev_index = self.current_index;
-                self.current_index = self.tokens.len();
-                Ok(Some(text.1.to_string()))
-            } else {
-                Ok(None)
-            }
-        }
-
-        pub fn decode_rest(&self) -> Result<Option<String>> {
-            let prev_text = if self.tokens.is_empty() {
-                String::new()
-            } else {
-                let tokens = &self.tokens[self.prev_index..self.current_index];
-                self.decode(tokens)?
-            };
-            let text = self.decode(&self.tokens[self.prev_index..])?;
-            if text.len() > prev_text.len() {
-                let text = text.split_at(prev_text.len());
-                Ok(Some(text.1.to_string()))
-            } else {
-                Ok(None)
-            }
-        }
-
-        pub fn decode_all(&self) -> Result<String> {
-            self.decode(&self.tokens)
-        }
-
-        pub fn get_token(&self, token_s: &str) -> Option<u32> {
-            self.tokenizer.get_vocab(true).get(token_s).copied()
-        }
-
-        pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
-            &self.tokenizer
-        }
-
-        pub fn clear(&mut self) {
-            self.tokens.clear();
-            self.prev_index = 0;
-            self.current_index = 0;
-        }
-    }
+    use crate::{hub_load_safetensors, TokenOutputStream};

     enum Model {
         V1(Model1),
diff --git a/screenpipe-core/src/lib.rs b/screenpipe-core/src/lib.rs
index 2ad276a3..374b2be3 100644
--- a/screenpipe-core/src/lib.rs
+++ b/screenpipe-core/src/lib.rs
@@ -5,6 +5,10 @@ pub mod llm;
 #[cfg(feature = "llm")]
 pub use llm::*;
 #[cfg(feature = "llm")]
+pub mod phi;
+#[cfg(feature = "llm")]
+pub use phi::*;
+#[cfg(feature = "llm")]
 pub mod google;
 #[cfg(feature = "llm")]
 pub use google::*;
@@ -12,6 +16,10 @@ pub use google::*;
 pub mod mistral;
 #[cfg(feature = "llm")]
 pub use mistral::*;
+#[cfg(feature = "llm")]
+pub mod llama;
+#[cfg(feature = "llm")]
+pub use llama::*;
 #[cfg(feature = "pipes")]
 pub mod pipes;
 #[cfg(feature = "pipes")]
diff --git a/screenpipe-core/src/llama.rs b/screenpipe-core/src/llama.rs
new file mode 100644
index 00000000..050ab9f6
--- /dev/null
+++ b/screenpipe-core/src/llama.rs
@@ -0,0 +1,242 @@
+#[cfg(feature = "llm")]
+mod llm_module {
+
+    use anyhow::{Error as E, Result};
+
+    use candle::{DType, Device, Tensor};
+    use candle_nn::VarBuilder;
+    use candle_transformers::generation::{LogitsProcessor, Sampling};
+    use hf_hub::{
+        api::sync::{Api, ApiBuilder},
+        Repo, RepoType,
+    };
+
+    use candle_transformers::models::llama as model;
+    use model::{Llama, LlamaConfig};
+    use tokenizers::Tokenizer;
+
+    use crate::{hub_load_safetensors, TokenOutputStream};
+
+    const EOS_TOKEN: &str = "</s>";
+
+    #[derive(Clone, Debug, Copy, PartialEq, Eq)]
+    enum Which {
+        V1,
+        V2,
+        V3,
+        V31,
+        V3Instruct,
+        V31Instruct,
+        V32_1b,
+        V32_1bInstruct,
+        V32_3b,
+        V32_3bInstruct,
+        Solar10_7B,
+        TinyLlama1_1BChat,
+    }
+
+    #[derive(Debug)]
+    pub struct LlamaInitConfig {
+        /// The temperature used to generate samples.
+        temperature: f64,
+
+        /// Nucleus sampling probability cutoff.
+        top_p: Option<f64>,
+
+        /// Only sample among the top K samples.
+        top_k: Option<usize>,
+
+        /// The seed to use when generating random samples.
+        seed: u64,
+
+        /// The length of the sample to generate (in tokens).
+        sample_len: usize,
+
+        /// Disable the key-value cache.
+        no_kv_cache: bool,
+
+        /// The initial prompt.
+        prompt: Option<String>,
+
+        /// Use different dtype than f16
+        dtype: Option<String>,
+
+        model_id: Option<String>,
+
+        revision: Option<String>,
+
+        /// The model size to use.
+        which: Which,
+
+        use_flash_attn: bool,
+
+        /// Penalty to be applied for repeating tokens, 1. means no penalty.
+        repeat_penalty: f32,
+
+        /// The context size to consider for the repeat penalty.
+        repeat_last_n: usize,
+    }
+
+    impl Default for LlamaInitConfig {
+        fn default() -> Self {
+            Self {
+                use_flash_attn: false,
+                prompt: None,
+                temperature: 0.8,
+                top_p: Some(0.95),
+                top_k: None,
+                seed: 299792458,
+                sample_len: 100,
+                which: Which::V32_3bInstruct,
+                model_id: None,
+                revision: None,
+                repeat_penalty: 1.1,
+                repeat_last_n: 128,
+                no_kv_cache: false,
+                dtype: None,
+            }
+        }
+    }
+
+    pub fn llama_stream_text<F>(args: LlamaInitConfig, mut callback: F) -> Result<()>
+    where
+        F: FnMut(String) -> Result<()>,
+    {
+        println!(
+            "avx: {}, neon: {}, simd128: {}, f16c: {}",
+            candle::utils::with_avx(),
+            candle::utils::with_neon(),
+            candle::utils::with_simd128(),
+            candle::utils::with_f16c()
+        );
+        println!(
+            "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+            args.temperature, args.repeat_penalty, args.repeat_last_n
+        );
+
+        let start = std::time::Instant::now();
+        let api = ApiBuilder::new()
+            // ! hardcoded louis token dont CARE
+            .with_token(Some("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".to_string()))
+            .build()?;
+        let model_id = args.model_id.unwrap_or_else(|| match args.which {
+            Which::V1 => "Narsil/amall-7b".to_string(),
+            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
+            Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
+            Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
+            Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
+            Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
+            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
+        });
+        println!("loading the model weights from {model_id}");
+        let revision = args.revision.unwrap_or("main".to_string());
+        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+        let config = config.into_config(args.use_flash_attn);
+
+        let filenames = hub_load_safetensors(&api, "model.safetensors.index.json")?;
+        println!("retrieved the files in {:?}", start.elapsed());
+
+        let device = Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu));
+
+        let dtype = match args.dtype.as_deref() {
+            Some("f16") => DType::F16,
+            Some("bf16") => DType::BF16,
+            Some("f32") => DType::F32,
+            Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
+            None => DType::F16,
+        };
+
+        let start = std::time::Instant::now();
+        let mut cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        let llama = Llama::load(vb, &config)?;
+        println!("loaded the model in {:?}", start.elapsed());
+
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        let mut tokens = tokenizer
+            .encode(args.prompt.unwrap(), true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let mut tokenizer = TokenOutputStream::new(tokenizer);
+        for &t in tokens.iter() {
+            if let Some(t) = tokenizer.next_token(t)? {
+                callback(t)?;
+            }
+        }
+
+        let mut logits_processor = {
+            let temperature = args.temperature;
+            let sampling = if temperature <= 0. {
+                Sampling::ArgMax
+            } else {
+                match (args.top_k, args.top_p) {
+                    (None, None) => Sampling::All { temperature },
+                    (Some(k), None) => Sampling::TopK { k, temperature },
+                    (None, Some(p)) => Sampling::TopP { p, temperature },
+                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(args.seed, sampling)
+        };
+
+        let mut index_pos = 0;
+        let mut token_generated = 0;
+        let start_gen = std::time::Instant::now();
+        for index in 0..args.sample_len {
+            let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+                (1, index_pos)
+            } else {
+                (tokens.len(), 0)
+            };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
+            let logits = llama.forward(&input, context_index, &mut cache)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if args.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    args.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+            index_pos += ctxt.len();
+
+            let next_token = logits_processor.sample(&logits)?;
+            token_generated += 1;
+            tokens.push(next_token);
+
+            if let Some(t) = tokenizer.next_token(next_token)? {
+                callback(t)?;
+            }
+        }
+
+        if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+            callback(rest)?;
+        }
+
+        let dt = start_gen.elapsed();
+        println!(
+            "\n\n{} tokens generated ({} token/s)\n",
+            token_generated,
+            (token_generated - 1) as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+// Optionally, you can re-export the module contents if needed
+#[cfg(feature = "llm")]
+pub use llm_module::*;
diff --git a/screenpipe-core/src/llm.rs b/screenpipe-core/src/llm.rs
old mode 100755
new mode 100644
index 9d7d1bab..5e459d43
--- a/screenpipe-core/src/llm.rs
+++ b/screenpipe-core/src/llm.rs
@@ -1,136 +1,88 @@
 #[cfg(feature = "llm")]
 mod llm_module {
-    use anyhow::Result;
-    use candle::{DType, Device, Tensor};
-    use candle_nn::VarBuilder;
-    use candle_transformers::{
-        generation::LogitsProcessor,
-        models::phi3::{Config as Phi3Config, Model as Phi3},
-    };
-
-    use hf_hub::{api::sync::Api, Repo, RepoType};
-    use tokenizers::Tokenizer;
+    /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
+    /// streaming way rather than having to wait for the full decoding.
+    pub struct TokenOutputStream {
+        tokenizer: tokenizers::Tokenizer,
+        tokens: Vec<u32>,
+        prev_index: usize,
+        current_index: usize,
+    }

-    /// Loads the safetensors files for a model from the hub based on a json index file.
-    pub fn hub_load_safetensors(
-        repo: &hf_hub::api::sync::ApiRepo,
-        json_file: &str,
-    ) -> Result<Vec<std::path::PathBuf>> {
-        let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
-        let json_file = std::fs::File::open(json_file)?;
-        let json: serde_json::Value =
-            serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
-        let weight_map = match json.get("weight_map") {
-            None => anyhow::bail!("no weight map in {json_file:?}"),
-            Some(serde_json::Value::Object(map)) => map,
-            Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
-        };
-        let mut safetensors_files = std::collections::HashSet::new();
-        for value in weight_map.values() {
-            if let Some(file) = value.as_str() {
-                safetensors_files.insert(file.to_string());
+    impl TokenOutputStream {
+        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+            Self {
+                tokenizer,
+                tokens: Vec::new(),
+                prev_index: 0,
+                current_index: 0,
             }
         }
-        let safetensors_files = safetensors_files
-            .iter()
-            .map(|v| repo.get(v).map_err(anyhow::Error::from))
-            .collect::<Result<Vec<_>, anyhow::Error>>()?;
-        Ok(safetensors_files)
-    }
-
-    pub fn load_llama_model(device: &Device) -> Result<(Phi3, Tokenizer)> {
-        let api = Api::new()?;
-        let model_id = "microsoft/Phi-3-mini-4k-instruct";
-        let revision = "main";
-
-        let api = api.repo(Repo::with_revision(
-            model_id.to_string(),
-            RepoType::Model,
-            revision.to_string(),
-        ));
-        let tokenizer_filename = api.get("tokenizer.json")?;
-        let config_filename = api.get("config.json")?;
-
-        let config: Phi3Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-
-        // https://github.com/huggingface/candle/blob/ddafc61055601002622778b7762c15bd60057c1f/candle-examples/examples/phi/main.rs#L364
-        // let dtype = DType::BF16;
-        let dtype = DType::F32;
-        let filenames = hub_load_safetensors(&api, "model.safetensors.index.json")?;
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, device)? };
-
-        let model = Phi3::new(&config, vb)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
-
-        Ok((model, tokenizer))
-    }
-
-    pub fn generate_text_streaming<F>(
-        model: &mut Phi3,
-        tokenizer: &Tokenizer,
-        prompt: &str,
-        max_tokens: usize,
-        temperature: f64,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
-        seed: u64,
-        top_p: f64,
-        device: &Device,
-        mut callback: F,
-    ) -> Result<()>
-    where
-        F: FnMut(String) -> Result<()>,
-    {
-        let mut logits_processor = LogitsProcessor::new(seed, Some(temperature), Some(top_p));
-        let tokens = tokenizer.encode(prompt, true).unwrap();
-        if tokens.is_empty() {
-            anyhow::bail!("empty prompt")
+        pub fn into_inner(self) -> tokenizers::Tokenizer {
+            self.tokenizer
         }
-        let mut tokens = tokens.get_ids().to_vec();
-        let eos_token = match tokenizer.token_to_id("<|endoftext|>") {
-            Some(token) => token,
-            None => anyhow::bail!("cannot find the endoftext token"),
-        };
-        let mut pos = 0;
-        for _ in 0..max_tokens {
-            let context_size = if pos > 0 { 1 } else { tokens.len() };
-            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, pos)?;
-            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+        fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
+            match self.tokenizer.decode(tokens, true) {
+                Ok(str) => Ok(str),
+                Err(err) => anyhow::bail!("cannot decode: {err}"),
+            }
+        }

-            let logits = if repeat_penalty == 1. {
-                logits
+        // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+        pub fn next_token(&mut self, token: u32) -> anyhow::Result<Option<String>> {
+            let prev_text = if self.tokens.is_empty() {
+                String::new()
             } else {
-                let start_at = tokens.len().saturating_sub(repeat_last_n);
-                candle_transformers::utils::apply_repeat_penalty(
-                    &logits,
-                    repeat_penalty,
-                    &tokens[start_at..],
-                )?
+                let tokens = &self.tokens[self.prev_index..self.current_index];
+                self.decode(tokens)?
             };
+            self.tokens.push(token);
+            let text = self.decode(&self.tokens[self.prev_index..])?;
+            if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
+                let text = text.split_at(prev_text.len());
+                self.prev_index = self.current_index;
+                self.current_index = self.tokens.len();
+                Ok(Some(text.1.to_string()))
+            } else {
+                Ok(None)
+            }
+        }

-            // Remove the batch dimension if it exists
-            let logits = if logits.dims().len() > 1 {
-                logits.squeeze(0)?
+        pub fn decode_rest(&self) -> anyhow::Result<Option<String>> {
+            let prev_text = if self.tokens.is_empty() {
+                String::new()
             } else {
-                logits
+                let tokens = &self.tokens[self.prev_index..self.current_index];
+                self.decode(tokens)?
             };
-
-            let next_token = logits_processor.sample(&logits)?;
-            tokens.push(next_token);
-            if next_token == eos_token {
-                break;
-            }
-            if let Ok(t) = tokenizer.decode(&[next_token], false) {
-                callback(t)?;
+            let text = self.decode(&self.tokens[self.prev_index..])?;
+            if text.len() > prev_text.len() {
+                let text = text.split_at(prev_text.len());
+                Ok(Some(text.1.to_string()))
+            } else {
+                Ok(None)
             }
-            pos += 1;
         }
-        Ok(())
+
+        pub fn decode_all(&self) -> anyhow::Result<String> {
+            self.decode(&self.tokens)
+        }
+
+        pub fn get_token(&self, token_s: &str) -> Option<u32> {
+            self.tokenizer.get_vocab(true).get(token_s).copied()
+        }
+
+        pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+            &self.tokenizer
+        }
+
+        pub fn clear(&mut self) {
+            self.tokens.clear();
+            self.prev_index = 0;
+            self.current_index = 0;
+        }
     }
 }
diff --git a/screenpipe-core/src/mistral.rs b/screenpipe-core/src/mistral.rs
index b270a01e..501f34fe 100644
--- a/screenpipe-core/src/mistral.rs
+++ b/screenpipe-core/src/mistral.rs
@@ -13,92 +13,7 @@ mod llm_module {
     use hf_hub::{Repo, RepoType};
     use tokenizers::Tokenizer;

-    use crate::hub_load_safetensors;
-
-    /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
-    /// streaming way rather than having to wait for the full decoding.
-    pub struct TokenOutputStream {
-        tokenizer: tokenizers::Tokenizer,
-        tokens: Vec<u32>,
-        prev_index: usize,
-        current_index: usize,
-    }
-
-    impl TokenOutputStream {
-        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
-            Self {
-                tokenizer,
-                tokens: Vec::new(),
-                prev_index: 0,
-                current_index: 0,
-            }
-        }
-
-        pub fn into_inner(self) -> tokenizers::Tokenizer {
-            self.tokenizer
-        }
-
-        fn decode(&self, tokens: &[u32]) -> Result<String> {
-            match self.tokenizer.decode(tokens, true) {
-                Ok(str) => Ok(str),
-                Err(err) => anyhow::bail!("cannot decode: {err}"),
-            }
-        }
-
-        // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
-        pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
-            let prev_text = if self.tokens.is_empty() {
-                String::new()
-            } else {
-                let tokens = &self.tokens[self.prev_index..self.current_index];
-                self.decode(tokens)?
-            };
-            self.tokens.push(token);
-            let text = self.decode(&self.tokens[self.prev_index..])?;
-            if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
-                let text = text.split_at(prev_text.len());
-                self.prev_index = self.current_index;
-                self.current_index = self.tokens.len();
-                Ok(Some(text.1.to_string()))
-            } else {
-                Ok(None)
-            }
-        }
-
-        pub fn decode_rest(&self) -> Result<Option<String>> {
-            let prev_text = if self.tokens.is_empty() {
-                String::new()
-            } else {
-                let tokens = &self.tokens[self.prev_index..self.current_index];
-                self.decode(tokens)?
-            };
-            let text = self.decode(&self.tokens[self.prev_index..])?;
-            if text.len() > prev_text.len() {
-                let text = text.split_at(prev_text.len());
-                Ok(Some(text.1.to_string()))
-            } else {
-                Ok(None)
-            }
-        }
-
-        pub fn decode_all(&self) -> Result<String> {
-            self.decode(&self.tokens)
-        }
-
-        pub fn get_token(&self, token_s: &str) -> Option<u32> {
-            self.tokenizer.get_vocab(true).get(token_s).copied()
-        }
-
-        pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
-            &self.tokenizer
-        }
-
-        pub fn clear(&mut self) {
-            self.tokens.clear();
-            self.prev_index = 0;
-            self.current_index = 0;
-        }
-    }
+    use crate::{hub_load_safetensors, TokenOutputStream};

     enum Model {
         Mistral(Mistral),
@@ -255,12 +170,6 @@ mod llm_module {

     #[derive(Debug)]
     pub struct MistralConfig {
-        /// Run on CPU rather than on GPU.
-        cpu: bool,
-
-        /// Enable tracing (generates a trace-timestamp.json file).
-        tracing: bool,
-
         use_flash_attn: bool,

         prompt: String,
@@ -308,8 +217,6 @@ mod llm_module {
     impl Default for MistralConfig {
         fn default() -> Self {
             Self {
-                cpu: false,
-                tracing: false,
                 use_flash_attn: false,
                 prompt: String::new(),
                 temperature: Some(0.8),
diff --git a/screenpipe-core/src/phi.rs b/screenpipe-core/src/phi.rs
new file mode 100755
index 00000000..9d7d1bab
--- /dev/null
+++ b/screenpipe-core/src/phi.rs
@@ -0,0 +1,139 @@
+#[cfg(feature = "llm")]
+mod llm_module {
+    use anyhow::Result;
+    use candle::{DType, Device, Tensor};
+    use candle_nn::VarBuilder;
+    use candle_transformers::{
+        generation::LogitsProcessor,
+        models::phi3::{Config as Phi3Config, Model as Phi3},
+    };
+
+    use hf_hub::{api::sync::Api, Repo, RepoType};
+    use tokenizers::Tokenizer;
+
+    /// Loads the safetensors files for a model from the hub based on a json index file.
+    pub fn hub_load_safetensors(
+        repo: &hf_hub::api::sync::ApiRepo,
+        json_file: &str,
+    ) -> Result<Vec<std::path::PathBuf>> {
+        let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
+        let json_file = std::fs::File::open(json_file)?;
+        let json: serde_json::Value =
+            serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+        let weight_map = match json.get("weight_map") {
+            None => anyhow::bail!("no weight map in {json_file:?}"),
+            Some(serde_json::Value::Object(map)) => map,
+            Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
+        };
+        let mut safetensors_files = std::collections::HashSet::new();
+        for value in weight_map.values() {
+            if let Some(file) = value.as_str() {
+                safetensors_files.insert(file.to_string());
+            }
+        }
+        let safetensors_files = safetensors_files
+            .iter()
+            .map(|v| repo.get(v).map_err(anyhow::Error::from))
+            .collect::<Result<Vec<_>, anyhow::Error>>()?;
+        Ok(safetensors_files)
+    }
+
+    pub fn load_llama_model(device: &Device) -> Result<(Phi3, Tokenizer)> {
+        let api = Api::new()?;
+        let model_id = "microsoft/Phi-3-mini-4k-instruct";
+        let revision = "main";
+
+        let api = api.repo(Repo::with_revision(
+            model_id.to_string(),
+            RepoType::Model,
+            revision.to_string(),
+        ));
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+
+        let config: Phi3Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+
+        // https://github.com/huggingface/candle/blob/ddafc61055601002622778b7762c15bd60057c1f/candle-examples/examples/phi/main.rs#L364
+        // let dtype = DType::BF16;
+        let dtype = DType::F32;
+
+        let filenames = hub_load_safetensors(&api, "model.safetensors.index.json")?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, device)? };
+
+        let model = Phi3::new(&config, vb)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
+
+        Ok((model, tokenizer))
+    }
+
+    pub fn generate_text_streaming<F>(
+        model: &mut Phi3,
+        tokenizer: &Tokenizer,
+        prompt: &str,
+        max_tokens: usize,
+        temperature: f64,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        seed: u64,
+        top_p: f64,
+        device: &Device,
+        mut callback: F,
+    ) -> Result<()>
+    where
+        F: FnMut(String) -> Result<()>,
+    {
+        let mut logits_processor = LogitsProcessor::new(seed, Some(temperature), Some(top_p));
+        let tokens = tokenizer.encode(prompt, true).unwrap();
+        if tokens.is_empty() {
+            anyhow::bail!("empty prompt")
+        }
+        let mut tokens = tokens.get_ids().to_vec();
+        let eos_token = match tokenizer.token_to_id("<|endoftext|>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the endoftext token"),
+        };
+
+        let mut pos = 0;
+        for _ in 0..max_tokens {
+            let context_size = if pos > 0 { 1 } else { tokens.len() };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
+            let logits = model.forward(&input, pos)?;
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+
+            let logits = if repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+
+            // Remove the batch dimension if it exists
+            let logits = if logits.dims().len() > 1 {
+                logits.squeeze(0)?
+            } else {
+                logits
+            };
+
+            let next_token = logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            if next_token == eos_token {
+                break;
+            }
+            if let Ok(t) = tokenizer.decode(&[next_token], false) {
+                callback(t)?;
+            }
+            pos += 1;
+        }
+
+        Ok(())
+    }
+}
+
+// Optionally, you can re-export the module contents if needed
+#[cfg(feature = "llm")]
+pub use llm_module::*;