diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 35b7b9e8..49b9d8c9 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -5,17 +5,14 @@ use std::{ use anyhow::{Error as E, Result}; use candle::{Device, IndexOp, Tensor}; -use candle_nn::ops::softmax; +use candle_nn::{ops::softmax, VarBuilder}; use crossbeam::channel::{self, Receiver, Sender}; use hf_hub::{api::sync::Api, Repo, RepoType}; use log::{error, info}; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; -use candle_transformers::{ - models::whisper::{self as m, audio, Config}, - quantized_var_builder::VarBuilder, -}; +use candle_transformers::models::whisper::{self as m, audio, Config}; use rubato::{ Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction, }; @@ -37,24 +34,24 @@ impl WhisperModel { let (config_filename, tokenizer_filename, weights_filename) = { let api = Api::new()?; let repo = api.repo(Repo::with_revision( - "lmz/candle-whisper".to_string(), + "openai/whisper-tiny".to_string(), RepoType::Model, "main".to_string(), )); - let config = repo.get("config-tiny.json")?; - let tokenizer = repo.get("tokenizer-tiny.json")?; - let model = repo.get("model-tiny-q80.gguf")?; + let config = repo.get("config.json")?; + let tokenizer = repo.get("tokenizer.json")?; + let model = repo.get("model.safetensors")?; (config, tokenizer, model) }; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let vb = VarBuilder::from_gguf(weights_filename, &device)?; - let model = Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?); - + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? }; + let model = Model::Normal(m::model::Whisper::load(&vb, config.clone())?); Ok(Self { - model: model, + model, tokenizer, device, })