v0 dirty #241
louis030195 committed Sep 1, 2024
1 parent 194eb11 commit 6a96ea1
Showing 5 changed files with 115 additions and 4 deletions.
1 change: 1 addition & 0 deletions screenpipe-audio/src/lib.rs
@@ -2,6 +2,7 @@ mod core;
mod multilingual;
pub mod pcm_decode;
pub mod stt;
pub mod vad_engine;
pub use core::{
default_input_device, default_output_device, list_audio_devices, parse_audio_device,
record_and_transcribe, AudioDevice, AudioTranscriptionEngine, DeviceControl, DeviceType
15 changes: 12 additions & 3 deletions screenpipe-audio/src/stt.rs
@@ -17,7 +17,12 @@ use rubato::{
Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
};

use crate::{multilingual, pcm_decode::pcm_decode, AudioTranscriptionEngine};
use crate::{
multilingual,
pcm_decode::pcm_decode,
vad_engine::{SileroVad, VadEngine, VadEngineEnum, WebRtcVad},
AudioTranscriptionEngine,
};

use webrtc_vad::{Vad, VadMode};

@@ -719,6 +724,7 @@ pub struct TranscriptionResult {
}
pub async fn create_whisper_channel(
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
vad_engine: VadEngineEnum,
) -> Result<(
UnboundedSender<AudioInput>,
UnboundedReceiver<TranscriptionResult>,
@@ -732,7 +738,10 @@ pub async fn create_whisper_channel(
UnboundedSender<TranscriptionResult>,
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();

let mut vad_engine: Box<dyn VadEngine> = match vad_engine {
VadEngineEnum::WebRtc => Box::new(WebRtcVad::new()),
VadEngineEnum::Silero => Box::new(SileroVad::new().unwrap()),
};
tokio::spawn(async move {
loop {
tokio::select! {
@@ -742,7 +751,7 @@
.expect("Time went backwards")
.as_secs();

let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone(), &mut *vad_engine) {
Ok(transcription) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
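
A minimal caller sketch for the updated signature (illustrative, not part of this commit): the second argument is new in this change. The exact re-export paths, the AudioTranscriptionEngine variant name, and the full shape of the returned tuple are assumptions rather than facts taken from this diff.

use std::sync::Arc;

use screenpipe_audio::stt::create_whisper_channel;
use screenpipe_audio::vad_engine::VadEngineEnum;
use screenpipe_audio::AudioTranscriptionEngine;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Variant name is assumed; use whichever engine the crate actually exposes.
    let engine = Arc::new(AudioTranscriptionEngine::WhisperTiny);

    // New in this commit: the caller picks the VAD backend (WebRtc or Silero).
    let channels = create_whisper_channel(engine, VadEngineEnum::WebRtc).await?;

    // `channels` bundles the AudioInput sender and TranscriptionResult receiver
    // shown in the hunk above; destructure it as the real return type dictates.
    drop(channels);
    Ok(())
}
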
78 changes: 78 additions & 0 deletions screenpipe-audio/src/vad_engine.rs
@@ -0,0 +1,78 @@
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;

pub enum VadEngineEnum {
WebRtc,
Silero,
}

pub trait VadEngine {
fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool>;
}

pub struct WebRtcVad(webrtc_vad::Vad);

impl WebRtcVad {
pub fn new() -> Self {
let mut vad = webrtc_vad::Vad::new();
vad.set_mode(webrtc_vad::VadMode::Quality);
Self(vad)
}
}

impl VadEngine for WebRtcVad {
fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool> {
self.0.is_voice_segment(audio_chunk).map_err(Into::into)
}
}

pub struct SileroVad {
model: candle_nn::Module,
device: Device,
}

impl SileroVad {
pub fn new() -> Result<Self> {
let device = Device::Cpu;
let repo = Repo::with_revision(
"snakers4/silero-vad".to_string(),
RepoType::Model,
"master".to_string(),
);
let api = Api::new()?;
let api = api.repo(repo);
let model_path: PathBuf = api.get("silero_vad.onnx")?;

let vb = VarBuilder::from_onnx(model_path, &device)?;
let model = candle_nn::Module::new(vb)?;

Ok(Self { model, device })
}

fn preprocess_audio(&self, audio_chunk: &[i16]) -> anyhow::Result<Tensor> {
let float_chunk: Vec<f32> = audio_chunk.iter().map(|&x| x as f32 / 32768.0).collect();
Tensor::from_vec(float_chunk, (1, audio_chunk.len()), &self.device)
}
}

impl VadEngine for SileroVad {
fn is_voice_segment(&mut self, audio_chunk: &[i16]) -> Result<bool> {
let input = self.preprocess_audio(audio_chunk)?;
let output = self.model.forward(&input)?;
let probability = output.squeeze(0)?.squeeze(0)?.to_vec1::<f32>()?[0];

// You may need to adjust this threshold based on your specific use case
const VOICE_THRESHOLD: f32 = 0.5;
Ok(probability > VOICE_THRESHOLD)
}
}

pub fn create_vad_engine(engine: VadEngineEnum) -> Result<Box<dyn VadEngine>> {
match engine {
VadEngineEnum::WebRtc => Ok(Box::new(WebRtcVad::new())),
VadEngineEnum::Silero => Ok(Box::new(SileroVad::new()?)),
}
}
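
A hypothetical usage sketch for the new module (not part of the commit): build a VAD through create_vad_engine and classify fixed-size PCM frames. The 480-sample frame assumes 16 kHz mono, 30 ms chunks; webrtc-vad only accepts 10/20/30 ms frames, and the sample rate this wrapper is configured for is not visible in the diff, so treat the frame length as a placeholder.

use screenpipe_audio::vad_engine::{create_vad_engine, VadEngineEnum};

fn main() -> anyhow::Result<()> {
    let mut vad = create_vad_engine(VadEngineEnum::WebRtc)?;

    // Stand-in for decoded 16-bit PCM: one second of silence at 16 kHz.
    let audio: Vec<i16> = vec![0; 16_000];
    let frame_len = 480; // placeholder frame size, see note above

    for frame in audio.chunks_exact(frame_len) {
        // A real caller (e.g. stt.rs) would skip transcription for non-voice frames.
        if vad.is_voice_segment(frame)? {
            println!("voice frame detected");
        }
    }
    Ok(())
}
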
21 changes: 21 additions & 0 deletions screenpipe-server/src/cli.rs
@@ -50,6 +50,23 @@ impl From<CliOcrEngine> for CoreOcrEngine {
}
}

#[derive(Clone, Debug, ValueEnum, PartialEq)]
pub enum CliVadEngine {
#[clap(name = "webrtc")]
WebRtc,
#[clap(name = "silero")]
Silero,
}

impl From<CliVadEngine> for VadEngineEnum {
fn from(cli_engine: CliVadEngine) -> Self {
match cli_engine {
CliVadEngine::WebRtc => VadEngineEnum::WebRtc,
CliVadEngine::Silero => VadEngineEnum::Silero,
}
}
}

#[derive(Parser)]
#[command(
author,
@@ -151,4 +168,8 @@ pub struct Cli {
/// Restart recording process every X minutes (0 means no periodic restart)
#[arg(long, default_value_t = 0)]
pub restart_interval: u64,

/// VAD engine to use for speech detection
#[arg(long, value_enum, default_value_t = CliVadEngine::WebRtc)]
pub vad_engine: CliVadEngine,
}
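
A hypothetical wiring sketch (not part of the commit) for the new flag: parse Cli, then convert the CliVadEngine value into VadEngineEnum via the From impl added above. The screenpipe_server::cli module path and the surrounding main() are assumptions.

use clap::Parser;
use screenpipe_audio::vad_engine::VadEngineEnum;
use screenpipe_server::cli::Cli;

fn main() {
    // e.g. `screenpipe --vad-engine silero`; defaults to webrtc.
    let cli = Cli::parse();
    let vad: VadEngineEnum = cli.vad_engine.clone().into();
    match vad {
        VadEngineEnum::WebRtc => println!("using webrtc vad"),
        VadEngineEnum::Silero => println!("using silero vad"),
    }
}
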
4 changes: 3 additions & 1 deletion screenpipe-server/src/core.rs
@@ -1,3 +1,4 @@
use crate::cli::CliVadEngine;
use crate::{DatabaseManager, VideoCapture};
use anyhow::Result;
use chrono::Utc;
@@ -32,6 +33,7 @@ pub async fn start_continuous_recording(
friend_wearable_uid: Option<String>,
monitor_id: u32,
use_pii_removal: bool,
vad_engine: CliVadEngine,
) -> Result<()> {
let (whisper_sender, whisper_receiver) = if audio_disabled {
// Create a dummy channel if no audio devices are available, e.g. audio disabled
@@ -43,7 +45,7 @@
) = unbounded_channel();
(input_sender, output_receiver)
} else {
create_whisper_channel(audio_transcription_engine.clone()).await?
create_whisper_channel(audio_transcription_engine.clone(), vad_engine).await?
};
let db_manager_video = Arc::clone(&db);
let db_manager_audio = Arc::clone(&db);
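
A standalone sketch of the dummy-channel pattern this hunk keeps for the audio-disabled path (simplified, not the commit's exact code): two unbounded channels are created and only the input sender and output receiver are kept, so the recording loop gets the same channel types whether or not the whisper task is running. The AudioInput and TranscriptionResult stand-ins below are assumptions.

use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

// Simplified stand-ins for screenpipe_audio::AudioInput and TranscriptionResult.
struct AudioInput;
struct TranscriptionResult;

fn dummy_whisper_channel() -> (
    UnboundedSender<AudioInput>,
    UnboundedReceiver<TranscriptionResult>,
) {
    // Keep the input sender and the output receiver; dropping the other halves
    // closes both channels, so sends error out and the receiver yields nothing.
    let (input_sender, _input_receiver) = unbounded_channel::<AudioInput>();
    let (_output_sender, output_receiver) = unbounded_channel::<TranscriptionResult>();
    (input_sender, output_receiver)
}

fn main() {
    let (_whisper_sender, _whisper_receiver) = dummy_whisper_channel();
}
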
