Merge pull request #82 from louis030195/deepgram-audio-now
Deepgram audio now
m13v authored Aug 1, 2024
2 parents 5027aab + f5cecb5 commit 69cf382
Showing 7 changed files with 205 additions and 58 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -87,6 +87,10 @@ if you want to run screenpipe in debug mode to show more logs in terminal:
```bash
screenpipe --debug
```
by default screenpipe uses the deepgram nova-2 speech-to-text model via the cloud api. to use whisper-tiny, which runs locally, add this flag:
```bash
screenpipe --cloud-audio-off
```

you can combine multiple flags if needed
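for example, combining the two flags shown above gives debug logging plus local-only transcription:
```bash
screenpipe --debug --cloud-audio-off
```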

3 changes: 3 additions & 0 deletions screenpipe-audio/Cargo.toml
@@ -63,6 +63,9 @@ tokio = { workspace = true }
# Detect speech/silence
webrtc-vad = "0.4.0"

# Deepgram
reqwest = { version = "0.12.5", features = ["json", "blocking"] }

screenpipe-core = { path = "../screenpipe-core" }

[dev-dependencies]
8 changes: 6 additions & 2 deletions screenpipe-audio/src/bin/screenpipe-audio.rs
@@ -25,6 +25,9 @@ struct Args {

#[clap(long, help = "List available audio devices")]
list_audio_devices: bool,

#[clap(long, help = "Disable cloud audio processing")]
cloud_audio_off: bool,
}
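a possible invocation of this standalone binary with the new flag, assuming it is built under the file's name (`screenpipe-audio`):
```bash
# list capture devices, then record with cloud transcription disabled
screenpipe-audio --list-audio-devices
screenpipe-audio --cloud-audio-off
```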

fn print_devices(devices: &[AudioDevice]) {
@@ -74,7 +77,8 @@ async fn main() -> Result<()> {

let chunk_duration = Duration::from_secs(5);
let output_path = PathBuf::from("output.mp4");
let (whisper_sender, mut whisper_receiver) = create_whisper_channel().await?;
let cloud_audio = !args.cloud_audio_off;
let (whisper_sender, mut whisper_receiver) = create_whisper_channel(cloud_audio).await?;
// Spawn threads for each device
let recording_threads: Vec<_> = devices
.into_iter()
@@ -128,4 +132,4 @@ async fn main() -> Result<()> {
}

Ok(())
}
}
224 changes: 176 additions & 48 deletions screenpipe-audio/src/stt.rs
@@ -18,6 +18,9 @@ use crate::{multilingual, pcm_decode::pcm_decode};

use webrtc_vad::{Vad, VadMode};

use hound::{WavSpec, WavWriter};
use std::io::Cursor;

#[derive(Clone)]
pub struct WhisperModel {
pub model: Model,
@@ -399,7 +402,79 @@ enum Task {
Translate,
}

pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
use reqwest::blocking::Client;
use serde_json::Value;

// NOTE: hardcoded Deepgram API key; ideally this would come from an
// environment variable or configuration rather than the repository.
fn get_deepgram_api_key() -> String {
"7ed2a159a094337b01fd8178b914b7ae0e77822d".to_string()
}

fn transcribe_with_deepgram(api_key: &str, audio_data: &[f32]) -> Result<String> {
debug!("Starting Deepgram transcription");
let client = Client::new();

// Create a WAV file in memory
let mut cursor = Cursor::new(Vec::new());
{
let spec = WavSpec {
channels: 1,
sample_rate: 16000,
bits_per_sample: 32,
sample_format: hound::SampleFormat::Float,
};
let mut writer = WavWriter::new(&mut cursor, spec)?;
for &sample in audio_data {
writer.write_sample(sample)?;
}
writer.finalize()?;
}

// Get the WAV data from the cursor
let wav_data = cursor.into_inner();

let response = client.post("https://api.deepgram.com/v1/listen?model=nova-2&smart_format=true")
.header("Content-Type", "audio/wav")
.header("Authorization", format!("Token {}", api_key))
.body(wav_data)
.send();

match response {
Ok(resp) => {
debug!("Received response from Deepgram API");
match resp.json::<Value>() {
Ok(result) => {
debug!("Successfully parsed JSON response");
if let Some(err_code) = result.get("err_code") {
error!("Deepgram API error code: {:?}, result: {:?}", err_code, result);
return Err(anyhow::anyhow!("Deepgram API error: {:?}", result));
}
let transcription = result["results"]["channels"][0]["alternatives"][0]["transcript"]
.as_str()
.unwrap_or("");

if transcription.is_empty() {
info!("Transcription is empty. Full response: {:?}", result);
} else {
info!("Transcription successful. Length: {} characters", transcription.len());
}

Ok(transcription.to_string())
},
Err(e) => {
error!("Failed to parse JSON response: {:?}", e);
Err(anyhow::anyhow!("Failed to parse JSON response: {:?}", e))
}
}
},
Err(e) => {
error!("Failed to send request to Deepgram API: {:?}", e);
Err(anyhow::anyhow!("Failed to send request to Deepgram API: {:?}", e))
}
}
}
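a minimal caller sketch for the function above, assuming 16 kHz mono f32 samples as produced by the recording pipeline (the silent buffer here is only a placeholder):
```rust
// hypothetical usage: one second of silence at 16 kHz, matching the WavSpec above
let samples = vec![0.0f32; 16_000];
match transcribe_with_deepgram(&get_deepgram_api_key(), &samples) {
    Ok(text) => println!("transcript: {}", text),
    Err(e) => eprintln!("deepgram request failed: {:?}", e),
}
```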

pub fn stt(file_path: &str, whisper_model: &WhisperModel, cloud_audio: bool) -> Result<String> {
debug!("Starting speech to text for file: {}", file_path);
let model = &whisper_model.model;
let tokenizer = &whisper_model.tokenizer;
@@ -431,7 +506,7 @@ pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
vad.set_mode(VadMode::VeryAggressive); // Set mode to very aggressive

// Filter out non-speech segments
// debug!("VAD: Filtering out non-speech segments");
debug!("VAD: Filtering out non-speech segments");
let frame_size = 160; // 10ms frame size for 16kHz audio
let mut speech_frames = Vec::new();
for (frame_index, chunk) in pcm_data.chunks(frame_size).enumerate() {
@@ -440,21 +515,21 @@ pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
match vad.is_voice_segment(&i16_chunk) {
Ok(is_voice) => {
if is_voice {
debug!("VAD: Speech detected in frame {}", frame_index);
// debug!("VAD: Speech detected in frame {}", frame_index);
speech_frames.extend_from_slice(chunk);
} else {
// debug!("VAD: Non-speech frame {} filtered out", frame_index);
}
},
Err(e) => {
error!("VAD failed for frame {}: {:?}", frame_index, e);
debug!("VAD failed for frame {}: {:?}", frame_index, e);
// Optionally, you can choose to include the frame if VAD fails
// speech_frames.extend_from_slice(chunk);
}
}
}

debug!("Total frames processed: {}, Speech frames: {}", pcm_data.len() / frame_size, speech_frames.len() / frame_size);
info!("Total audio_frames processed: {}, frames that include speech: {}", pcm_data.len() / frame_size, speech_frames.len() / frame_size);

// If no speech frames detected, skip processing
if speech_frames.is_empty() {
@@ -464,47 +539,100 @@ pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {

debug!("Using {} speech frames out of {} total frames", speech_frames.len() / frame_size, pcm_data.len() / frame_size);

debug!("Converting PCM to mel spectrogram");
// let mel = audio::pcm_to_mel(&model.config(), &pcm_data, &mel_filters);
let mel = audio::pcm_to_mel(&model.config(), &speech_frames, &mel_filters);
let mel_len = mel.len();
debug!("Creating tensor from mel spectrogram");
let mel = Tensor::from_vec(
mel,
(
1,
model.config().num_mel_bins,
mel_len / model.config().num_mel_bins,
),
&device,
)?;

debug!("Detecting language");
let language_token = Some(multilingual::detect_language(
&mut model.clone(),
&tokenizer,
&mel,
)?);
let mut model = model.clone();
debug!("Initializing decoder");
let mut dc = Decoder::new(
&mut model,
tokenizer,
42,
&device,
language_token,
Some(Task::Transcribe),
true,
false,
)?;
debug!("Starting decoding process");
let segments = dc.run(&mel)?;
debug!("Decoding complete");
Ok(segments
.iter()
.map(|s| s.dr.text.clone())
.collect::<Vec<String>>()
.join("\n"))
if cloud_audio {
// Deepgram implementation
let api_key = get_deepgram_api_key();
match transcribe_with_deepgram(&api_key, &speech_frames) {
Ok(transcription) => Ok(transcription),
Err(e) => {
error!("Deepgram transcription failed, falling back to Whisper: {:?}", e);
// Existing Whisper implementation
debug!("Converting PCM to mel spectrogram");
let mel = audio::pcm_to_mel(&model.config(), &speech_frames, &mel_filters);
let mel_len = mel.len();
debug!("Creating tensor from mel spectrogram");
let mel = Tensor::from_vec(
mel,
(
1,
model.config().num_mel_bins,
mel_len / model.config().num_mel_bins,
),
&device,
)?;

debug!("Detecting language");
let language_token = Some(multilingual::detect_language(
&mut model.clone(),
&tokenizer,
&mel,
)?);
let mut model = model.clone();
debug!("Initializing decoder");
let mut dc = Decoder::new(
&mut model,
tokenizer,
42,
&device,
language_token,
Some(Task::Transcribe),
true,
false,
)?;
debug!("Starting decoding process");
let segments = dc.run(&mel)?;
debug!("Decoding complete");
Ok(segments
.iter()
.map(|s| s.dr.text.clone())
.collect::<Vec<String>>()
.join("\n"))
}
}
} else {
// Existing Whisper implementation
debug!("Starting Whisper transcription");
debug!("Converting PCM to mel spectrogram");
let mel = audio::pcm_to_mel(&model.config(), &speech_frames, &mel_filters);
let mel_len = mel.len();
debug!("Creating tensor from mel spectrogram");
let mel = Tensor::from_vec(
mel,
(
1,
model.config().num_mel_bins,
mel_len / model.config().num_mel_bins,
),
&device,
)?;

debug!("Detecting language");
let language_token = Some(multilingual::detect_language(
&mut model.clone(),
&tokenizer,
&mel,
)?);
let mut model = model.clone();
debug!("Initializing decoder");
let mut dc = Decoder::new(
&mut model,
tokenizer,
42,
&device,
language_token,
Some(Task::Transcribe),
true,
false,
)?;
debug!("Starting decoding process");
let segments = dc.run(&mel)?;
debug!("Decoding complete");
Ok(segments
.iter()
.map(|s| s.dr.text.clone())
.collect::<Vec<String>>()
.join("\n"))
}
}
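a short caller sketch for the new signature, mirroring how the tests below use it; passing true prefers deepgram and falls back to whisper on error:
```rust
// hypothetical caller: transcribe a file, preferring the cloud path
let whisper_model = WhisperModel::new()?;
let transcript = stt("./test_data/selah.mp4", &whisper_model, true)?;
println!("{}", transcript);
```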

fn resample(input: Vec<f32>, from_sample_rate: u32, to_sample_rate: u32) -> Result<Vec<f32>> {
@@ -545,7 +673,7 @@ pub struct TranscriptionResult {
pub timestamp: u64,
pub error: Option<String>,
}
pub async fn create_whisper_channel() -> Result<(
pub async fn create_whisper_channel(cloud_audio: bool) -> Result<(
UnboundedSender<AudioInput>,
UnboundedReceiver<TranscriptionResult>,
)> {
@@ -568,7 +696,7 @@ pub async fn create_whisper_channel() -> Result<(
.expect("Time went backwards")
.as_secs();

let result = stt(&input.path, &whisper_model);
let result = stt(&input.path, &whisper_model, cloud_audio);

let transcription_result = match result {
Ok(transcription) => TranscriptionResult {
12 changes: 7 additions & 5 deletions screenpipe-audio/tests/core_tests.rs
@@ -3,14 +3,15 @@ mod tests {
use chrono::Utc;
use log::{debug, LevelFilter};
use screenpipe_audio::{default_output_device, list_audio_devices, stt, WhisperModel};
use screenpipe_audio::{parse_audio_device, record_and_transcribe};
use screenpipe_audio::{parse_audio_device, record_and_transcribe, create_whisper_channel};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc::unbounded_channel;
use tokio::time::timeout;

fn setup() {
// Initialize the logger with an info level filter
@@ -43,12 +44,12 @@ mod tests {
fn test_speech_to_text() {
setup();
println!("Starting speech to text test");

println!("Loading audio file");
let start = std::time::Instant::now();
let whisper_model = WhisperModel::new().unwrap();
let cloud_audio = true; // Set this based on your test requirements

let text = stt("./test_data/selah.mp4", &whisper_model).unwrap();
let text = stt("./test_data/selah.mp4", &whisper_model, cloud_audio).unwrap();
let duration = start.elapsed();

println!("Speech to text completed in {:?}", duration);
@@ -223,7 +224,8 @@ mod tests {
let output_path =
PathBuf::from(format!("test_output_{}.mp4", Utc::now().timestamp_millis()));
let output_path_2 = output_path.clone();
let (whisper_sender, mut whisper_receiver) = create_whisper_channel().await.unwrap();
let cloud_audio = true; // Set this based on your test requirements
let (whisper_sender, mut whisper_receiver) = create_whisper_channel(cloud_audio).await.unwrap();
let is_running = Arc::new(AtomicBool::new(true));
// Start recording in a separate thread
let recording_thread = tokio::spawn(async move {
@@ -285,4 +287,4 @@ mod tests {
let _ = recording_thread.abort();
std::fs::remove_file(output_path_2).unwrap_or_default();
}
}
}
