diff --git a/screenpipe-audio/src/lib.rs b/screenpipe-audio/src/lib.rs
index e2f09087..1cc31fa5 100644
--- a/screenpipe-audio/src/lib.rs
+++ b/screenpipe-audio/src/lib.rs
@@ -2,6 +2,7 @@ mod core;
 mod multilingual;
 pub mod pcm_decode;
 pub mod stt;
+pub mod vad_engine;
 pub use core::{
     default_input_device, default_output_device, list_audio_devices, parse_audio_device,
     record_and_transcribe, AudioDevice, AudioTranscriptionEngine, DeviceControl, DeviceType
diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs
--- a/screenpipe-audio/src/stt.rs
+++ b/screenpipe-audio/src/stt.rs
@@ -17,7 +17,12 @@ use rubato::{
     Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
 };
 
-use crate::{multilingual, pcm_decode::pcm_decode, AudioTranscriptionEngine};
+use crate::{
+    multilingual,
+    pcm_decode::pcm_decode,
+    vad_engine::{create_vad_engine, VadEngine, VadEngineEnum},
+    AudioTranscriptionEngine,
+};
 
 use webrtc_vad::{Vad, VadMode};
 
@@ -719,6 +724,7 @@ pub struct TranscriptionResult {
 }
 pub async fn create_whisper_channel(
     audio_transcription_engine: Arc,
+    vad_engine: VadEngineEnum,
 ) -> Result<(
     UnboundedSender,
     UnboundedReceiver,
@@ -732,7 +738,10 @@ pub async fn create_whisper_channel(
         UnboundedSender,
         UnboundedReceiver,
     ) = unbounded_channel();
 
+    // Build the detector before spawning so that a failed model load is
+    // returned to the caller through `?`.
+    let mut vad_engine: Box<dyn VadEngine> = create_vad_engine(vad_engine)?;
     tokio::spawn(async move {
         loop {
             tokio::select! {
@@ -742,7 +751,7 @@ pub async fn create_whisper_channel(
                         .expect("Time went backwards")
                         .as_secs();
 
-                    let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
+                    let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone(), &mut *vad_engine) {
                         Ok(transcription) => TranscriptionResult {
                             input: input.clone(),
                             transcription: Some(transcription),
diff --git a/screenpipe-audio/src/vad_engine.rs b/screenpipe-audio/src/vad_engine.rs
new file mode 100644
--- /dev/null
+++ b/screenpipe-audio/src/vad_engine.rs
@@ -0,0 +1,136 @@
+//! Voice-activity-detection (VAD) backends behind the common [`VadEngine`] trait.
+
+use anyhow::{anyhow, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use std::path::PathBuf;
+
+/// Selects which VAD backend [`create_vad_engine`] builds.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum VadEngineEnum {
+    /// Detector backed by the `webrtc_vad` crate.
+    WebRtc,
+    /// Detector backed by the Silero VAD model.
+    Silero,
+}
+
+/// Decides whether a chunk of 16-bit PCM samples contains speech.
+///
+/// `Send` is a supertrait because `create_whisper_channel` moves the boxed
+/// engine into a `tokio::spawn` task, which only accepts `Send` futures.
+pub trait VadEngine: Send {
+    /// Returns `Ok(true)` when `audio_chunk` is classified as voice.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the backend cannot process the chunk.
+    fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool>;
+}
+
+/// [`VadEngine`] adapter around [`webrtc_vad::Vad`].
+pub struct WebRtcVad(webrtc_vad::Vad);
+
+// SAFETY: NOTE(review) - presumably `webrtc_vad::Vad` is not `Send` only because
+// it wraps a raw handle with no thread affinity. The engine is owned by one
+// task and only reached through `&mut self`; confirm before merging.
+unsafe impl Send for WebRtcVad {}
+
+impl WebRtcVad {
+    /// Creates a detector running in [`webrtc_vad::VadMode::Quality`].
+    pub fn new() -> Self {
+        let mut vad = webrtc_vad::Vad::new();
+        vad.set_mode(webrtc_vad::VadMode::Quality);
+        Self(vad)
+    }
+}
+
+impl Default for WebRtcVad {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl VadEngine for WebRtcVad {
+    fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool> {
+        // The wrapped call fails with `Err(())`, and `()` has no conversion into
+        // `anyhow::Error`, so the error is built by hand.
+        self.0
+            .is_voice_segment(audio_chunk)
+            .map_err(|()| anyhow!("webrtc vad could not process the audio chunk"))
+    }
+}
+
+// NOTE(review): `candle_nn::Module` is a trait, so it cannot be a field type or
+// be built with `Module::new`, and `VarBuilder` appears to have no `from_onnx`.
+// Loading `silero_vad.onnx` needs an ONNX-capable runtime, i.e. a new
+// dependency, so the loading code below is left exactly as submitted.
+/// [`VadEngine`] meant to run the Silero VAD model fetched from the Hugging Face hub.
+pub struct SileroVad {
+    model: candle_nn::Module,
+    device: Device,
+}
+
+impl SileroVad {
+    /// Fetches `silero_vad.onnx` from `snakers4/silero-vad` and loads it on the CPU.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the hub client cannot be created, the model file
+    /// cannot be fetched, or the model cannot be loaded.
+    pub fn new() -> Result<Self> {
+        let device = Device::Cpu;
+        let repo = Repo::with_revision(
+            "snakers4/silero-vad".to_string(),
+            RepoType::Model,
+            "master".to_string(),
+        );
+        let api = Api::new()?;
+        let api = api.repo(repo);
+        let model_path: PathBuf = api.get("silero_vad.onnx")?;
+
+        let vb = VarBuilder::from_onnx(model_path, &device)?;
+        let model = candle_nn::Module::new(vb)?;
+
+        Ok(Self { model, device })
+    }
+
+    /// Scales each sample by `1 / 32768` and packs the chunk into a `(1, len)` tensor.
+    fn preprocess_audio(&self, audio_chunk: &[i16]) -> anyhow::Result<Tensor> {
+        let float_chunk: Vec<f32> = audio_chunk.iter().map(|&x| x as f32 / 32768.0).collect();
+        // `Tensor::from_vec` returns candle's own `Result`; `?` converts its error.
+        Ok(Tensor::from_vec(float_chunk, (1, audio_chunk.len()), &self.device)?)
+    }
+}
+
+impl VadEngine for SileroVad {
+    fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool> {
+        // You may need to adjust this threshold based on your specific use case
+        const VOICE_THRESHOLD: f32 = 0.5;
+
+        let input = self.preprocess_audio(audio_chunk)?;
+        let output = self.model.forward(&input)?;
+        // Checked access: an empty output becomes an error instead of an index panic.
+        let probability = output
+            .squeeze(0)?
+            .squeeze(0)?
+            .to_vec1::<f32>()?
+            .first()
+            .copied()
+            .ok_or_else(|| anyhow!("silero vad returned an empty output"))?;
+
+        Ok(probability > VOICE_THRESHOLD)
+    }
+}
+
+/// Builds the boxed detector selected by `engine`.
+///
+/// # Errors
+///
+/// Returns an error when [`SileroVad::new`] fails.
+pub fn create_vad_engine(engine: VadEngineEnum) -> Result<Box<dyn VadEngine>> {
+    match engine {
+        VadEngineEnum::WebRtc => Ok(Box::new(WebRtcVad::new())),
+        VadEngineEnum::Silero => Ok(Box::new(SileroVad::new()?)),
+    }
+}
diff --git a/screenpipe-server/src/cli.rs b/screenpipe-server/src/cli.rs
--- a/screenpipe-server/src/cli.rs
+++ b/screenpipe-server/src/cli.rs
@@ -50,6 +50,26 @@ impl From for CoreOcrEngine {
     }
 }
 
+/// VAD backend selectable with `--vad-engine`.
+#[derive(Clone, Debug, ValueEnum, PartialEq)]
+pub enum CliVadEngine {
+    #[clap(name = "webrtc")]
+    WebRtc,
+    #[clap(name = "silero")]
+    Silero,
+}
+
+// NOTE(review): the full path is used because no `use` for `VadEngineEnum` is
+// visible in this file; assumes the audio crate is named `screenpipe_audio`.
+impl From<CliVadEngine> for screenpipe_audio::vad_engine::VadEngineEnum {
+    fn from(cli_engine: CliVadEngine) -> Self {
+        match cli_engine {
+            CliVadEngine::WebRtc => Self::WebRtc,
+            CliVadEngine::Silero => Self::Silero,
+        }
+    }
+}
+
 #[derive(Parser)]
 #[command(
     author,
@@ -151,4 +171,8 @@ pub struct Cli {
     /// Restart recording process every X minutes (0 means no periodic restart)
     #[arg(long, default_value_t = 0)]
     pub restart_interval: u64,
+
+    /// VAD engine to use for speech detection
+    #[arg(long, value_enum, default_value_t = CliVadEngine::WebRtc)]
+    pub vad_engine: CliVadEngine,
-}
\ No newline at end of file
+}
diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs
--- a/screenpipe-server/src/core.rs
+++ b/screenpipe-server/src/core.rs
@@ -1,3 +1,4 @@
+use crate::cli::CliVadEngine;
 use crate::{DatabaseManager, VideoCapture};
 use anyhow::Result;
 use chrono::Utc;
@@ -32,6 +33,7 @@ pub async fn start_continuous_recording(
     friend_wearable_uid: Option,
     monitor_id: u32,
     use_pii_removal: bool,
+    vad_engine: CliVadEngine,
 ) -> Result<()> {
     let (whisper_sender, whisper_receiver) = if audio_disabled {
         // Create a dummy channel if no audio devices are available, e.g. audio disabled
@@ -43,7 +45,7 @@ pub async fn start_continuous_recording(
         ) = unbounded_channel();
         (input_sender, output_receiver)
     } else {
-        create_whisper_channel(audio_transcription_engine.clone()).await?
+        create_whisper_channel(audio_transcription_engine.clone(), vad_engine.into()).await?
     };
     let db_manager_video = Arc::clone(&db);
     let db_manager_audio = Arc::clone(&db);