diff --git a/screenpipe-audio/src/bin/screenpipe-audio.rs b/screenpipe-audio/src/bin/screenpipe-audio.rs
index c0186d33..7465e263 100644
--- a/screenpipe-audio/src/bin/screenpipe-audio.rs
+++ b/screenpipe-audio/src/bin/screenpipe-audio.rs
@@ -1,6 +1,9 @@
 use anyhow::{anyhow, Result};
 use clap::Parser;
 use log::info;
+use screenpipe_audio::create_whisper_channel;
+use screenpipe_audio::default_input_device;
+use screenpipe_audio::default_output_device;
 use screenpipe_audio::list_audio_devices;
 use screenpipe_audio::parse_device_spec;
 use screenpipe_audio::record_and_transcribe;
@@ -12,8 +15,12 @@ use std::time::Duration;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
-    #[clap(short, long, help = "Audio device name")]
-    audio_device: Option<String>,
+    #[clap(
+        short,
+        long,
+        help = "Audio device name (can be specified multiple times)"
+    )]
+    audio_device: Vec<String>,
 
     #[clap(long, help = "List available audio devices")]
     list_audio_devices: bool,
@@ -45,51 +52,58 @@ fn main() -> Result<()> {
         return Ok(());
     }
 
-    let device = match args.audio_device {
-        Some(d) => parse_device_spec(&d).unwrap(),
-        None => {
-            if devices.is_empty() {
-                return Err(anyhow!("No audio input devices found"));
-            }
-            eprintln!("No audio device specified. Available devices are:");
-            print_devices(&devices);
-            eprintln!("\nPlease specify one or more devices with:");
-            eprintln!(
-                "  {} --audio-device \"Device Name (input)\" [--audio-device \"Another Device (output)\"]",
-                std::env::args().next().unwrap()
-            );
-            return Err(anyhow!("No device specified"));
-        }
+    let devices = if args.audio_device.is_empty() {
+        vec![default_input_device()?, default_output_device()?]
+    } else {
+        args.audio_device
+            .iter()
+            .map(|d| parse_device_spec(d))
+            .collect::<Result<Vec<_>>>()?
     };
 
-    let (result_tx, result_rx) = mpsc::channel();
+    if devices.is_empty() {
+        return Err(anyhow!("No audio input devices found"));
+    }
+
     let chunk_duration = Duration::from_secs(30);
     let output_path = PathBuf::from("output.wav");
 
-    // Spawn a thread to handle the recording and transcription
-    let recording_thread = thread::spawn(move || {
-        record_and_transcribe(&device, chunk_duration, result_tx, output_path)
-    });
+    let (whisper_sender, whisper_receiver) = create_whisper_channel()?;
+
+    // Spawn threads for each device
+    let recording_threads: Vec<_> = devices
+        .into_iter()
+        .enumerate()
+        .map(|(i, device)| {
+            let whisper_sender = whisper_sender.clone();
+            let output_path = output_path.with_file_name(format!("output_{}.wav", i));
+            thread::spawn(move || {
+                record_and_transcribe(&device, chunk_duration, output_path, whisper_sender)
+            })
+        })
+        .collect();
 
     // Main loop to receive and print transcriptions
     loop {
-        match result_rx.recv_timeout(Duration::from_secs(5)) {
+        match whisper_receiver.recv_timeout(Duration::from_secs(5)) {
             Ok(result) => {
-                info!("Transcription: {}", result.text);
+                info!("Transcription: {:?}", result);
             }
-            Err(mpsc::RecvTimeoutError::Timeout) => {
+            Err(crossbeam::channel::RecvTimeoutError::Timeout) => {
                 // No transcription received in 5 seconds, continue waiting
                 continue;
             }
-            Err(mpsc::RecvTimeoutError::Disconnected) => {
-                // Sender has been dropped, recording is complete
+            Err(crossbeam::channel::RecvTimeoutError::Disconnected) => {
+                // All senders have been dropped, recording is complete
                 break;
            }
         }
     }
 
-    // Wait for the recording thread to finish
-    let file_path = recording_thread.join().unwrap()?;
-    println!("Recording complete: {:?}", file_path);
+    // Wait for all recording threads to finish
+    for (i, thread) in recording_threads.into_iter().enumerate() {
+        let file_path = thread.join().unwrap()?;
+        println!("Recording {} complete: {:?}", i, file_path);
+    }
 
     Ok(())
 }
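The binary now fans audio from several devices into one transcription pipeline: each recording thread owns a clone of the crossbeam sender, and the main loop drains a single receiver. Below is a minimal standalone sketch of that fan-in pattern, assuming the `crossbeam` crate as a dependency; `Job` and the worker count are illustrative, not screenpipe types.

    // Fan-in sketch: producers clone one Sender, a single consumer drains
    // the shared Receiver until every clone has been dropped.
    use std::thread;
    use std::time::Duration;

    #[derive(Debug)]
    struct Job {
        device: usize,
        path: String,
    }

    fn main() {
        let (tx, rx) = crossbeam::channel::unbounded::<Job>();

        let workers: Vec<_> = (0..3)
            .map(|i| {
                let tx = tx.clone();
                thread::spawn(move || {
                    tx.send(Job {
                        device: i,
                        path: format!("output_{}.wav", i),
                    })
                    .unwrap();
                })
            })
            .collect();

        // Drop the original sender so Disconnected fires once the workers finish.
        drop(tx);

        loop {
            match rx.recv_timeout(Duration::from_secs(5)) {
                Ok(job) => println!("got {:?}", job),
                Err(crossbeam::channel::RecvTimeoutError::Timeout) => continue,
                Err(crossbeam::channel::RecvTimeoutError::Disconnected) => break,
            }
        }

        for w in workers {
            w.join().unwrap();
        }
    }

Worth noting: `Disconnected` only fires once every clone of a sender, including the original, is gone. In the patch above, `main` keeps the original `whisper_sender` alive for the lifetime of the loop, so the whisper thread never sees its input close and the receive loop can only ever hit the `Timeout` arm.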
diff --git a/screenpipe-audio/src/core.rs b/screenpipe-audio/src/core.rs
index e3a40129..8431abd3 100644
--- a/screenpipe-audio/src/core.rs
+++ b/screenpipe-audio/src/core.rs
@@ -12,6 +12,8 @@ use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 use std::{io::BufWriter, thread};
 
+use crate::AudioInput;
+
 pub struct AudioCaptureResult {
     pub text: String,
 }
@@ -82,7 +84,7 @@ pub fn record_and_transcribe(
     device_spec: &DeviceSpec,
     duration: Duration,
     output_path: PathBuf,
-    whisper_sender: Sender<String>,
+    whisper_sender: Sender<AudioInput>,
 ) -> Result<PathBuf> {
     let host = match device_spec {
         #[cfg(target_os = "macos")]
@@ -172,7 +174,10 @@ pub fn record_and_transcribe(
             writer.flush()?;
 
             // Send the file path to the whisper channel
-            whisper_sender.send(output_path.to_str().unwrap().to_string())?;
+            whisper_sender.send(AudioInput {
+                path: output_path.to_str().unwrap().to_string(),
+                device: device_spec.to_string(),
+            })?;
         }
     }
 
@@ -191,7 +196,10 @@ pub fn record_and_transcribe(
             writer.flush()?;
 
             // Send the file path to the whisper channel
-            whisper_sender.send(output_path.to_str().unwrap().to_string())?;
+            whisper_sender.send(AudioInput {
+                path: output_path.to_str().unwrap().to_string(),
+                device: device_spec.to_string(),
+            })?;
         }
     }
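core.rs no longer sends a bare path string down the channel; each flushed chunk is tagged with the device it came from, so a single consumer can demultiplex the output of several recorders. A sketch of that tagging idea in isolation follows; the `device` string is an assumed example of what `DeviceSpec`'s Display output might look like, not a value taken from the patch, and `crossbeam` is taken from the crate's existing dependencies.

    // Tagged-message sketch: the struct mirrors the AudioInput added in
    // stt.rs, so the consumer knows which recorder produced which file.
    #[derive(Debug, Clone)]
    struct AudioInput {
        path: String,
        device: String,
    }

    fn main() {
        let (tx, rx) = crossbeam::channel::unbounded::<AudioInput>();

        // What each recorder does after flushing a chunk to disk:
        tx.send(AudioInput {
            path: "output_0.wav".to_string(),
            device: "MacBook Pro Microphone (input)".to_string(), // assumed format
        })
        .unwrap();
        drop(tx);

        while let Ok(input) = rx.recv() {
            println!("{} produced {}", input.device, input.path);
        }
    }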
diff --git a/screenpipe-audio/src/lib.rs b/screenpipe-audio/src/lib.rs
index 38f62bd4..d5f6907c 100644
--- a/screenpipe-audio/src/lib.rs
+++ b/screenpipe-audio/src/lib.rs
@@ -6,4 +6,4 @@ pub use core::{
     default_input_device, default_output_device, list_audio_devices, parse_device_spec,
     record_and_transcribe, AudioCaptureResult, AudioDevice, DeviceSpec,
 };
-pub use stt::create_whisper_channel;
+pub use stt::{create_whisper_channel, AudioInput, TranscriptionResult};
diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs
index 18690bd1..35b7b9e8 100644
--- a/screenpipe-audio/src/stt.rs
+++ b/screenpipe-audio/src/stt.rs
@@ -1,4 +1,7 @@
-use std::thread;
+use std::{
+    thread,
+    time::{SystemTime, UNIX_EPOCH},
+};
 
 use anyhow::{Error as E, Result};
 use candle::{Device, IndexOp, Tensor};
@@ -393,7 +396,7 @@ enum Task {
     Translate,
 }
 
-pub fn stt(input: &str, whisper_model: &WhisperModel) -> Result<String> {
+pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
     info!("Starting speech to text");
     let mut model = &whisper_model.model;
     let tokenizer = &whisper_model.tokenizer;
@@ -407,7 +410,7 @@ pub fn stt(input: &str, whisper_model: &WhisperModel) -> Result<String> {
     let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
     <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
 
-    let (mut pcm_data, sample_rate) = pcm_decode(input)?;
+    let (mut pcm_data, sample_rate) = pcm_decode(file_path)?;
     if sample_rate != m::SAMPLE_RATE as u32 {
         info!(
             "Resampling from {} Hz to {} Hz",
@@ -479,24 +482,54 @@ fn resample(input: Vec<f32>, from_sample_rate: u32, to_sample_rate: u32) -> Resu
     Ok(waves_out.into_iter().next().unwrap())
 }
 
-pub fn create_whisper_channel() -> Result<(Sender<String>, Receiver<String>)> {
+#[derive(Debug, Clone)]
+pub struct AudioInput {
+    pub path: String,
+    pub device: String,
+}
+
+#[derive(Debug, Clone)]
+pub struct TranscriptionResult {
+    pub input: AudioInput,
+    pub transcription: Option<String>,
+    pub timestamp: u64,
+    pub error: Option<String>,
+}
+pub fn create_whisper_channel() -> Result<(Sender<AudioInput>, Receiver<TranscriptionResult>)> {
     let whisper_model = WhisperModel::new()?;
-    let (input_sender, input_receiver): (Sender<String>, Receiver<String>) = channel::unbounded();
-    let (output_sender, output_receiver): (Sender<String>, Receiver<String>) = channel::unbounded();
+    let (input_sender, input_receiver): (Sender<AudioInput>, Receiver<AudioInput>) =
+        channel::unbounded();
+    let (output_sender, output_receiver): (
+        Sender<TranscriptionResult>,
+        Receiver<TranscriptionResult>,
+    ) = channel::unbounded();
 
     thread::spawn(move || {
         while let Ok(input) = input_receiver.recv() {
-            match stt(&input, &whisper_model) {
-                Ok(result) => {
-                    if output_sender.send(result).is_err() {
-                        break;
-                    }
-                }
-                Err(e) => {
-                    if output_sender.send(format!("Error: {}", e)).is_err() {
-                        break;
-                    }
-                }
+            let timestamp = SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_secs();
+
+            let result = stt(&input.path, &whisper_model);
+
+            let transcription_result = match result {
+                Ok(transcription) => TranscriptionResult {
+                    input: input.clone(),
+                    transcription: Some(transcription),
+                    timestamp,
+                    error: None,
+                },
+                Err(e) => TranscriptionResult {
+                    input: input.clone(),
+                    transcription: None,
+                    timestamp,
+                    error: Some(e.to_string()),
+                },
+            };
+
+            if output_sender.send(transcription_result).is_err() {
+                break;
             }
         }
     });
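stt.rs replaces the old stringly-typed protocol, where failures were smuggled through the channel as "Error: ..." text, with a `TranscriptionResult` that carries either a transcript or an error string, plus a UNIX timestamp. A self-contained sketch of that wrap-every-outcome pattern, assuming the `anyhow` crate; `wrap` is a hypothetical helper, not part of the patch:

    // Result-wrapping sketch: success and failure travel down the same
    // channel as one struct the consumer can match on, instead of parsing
    // "Error: ..." prefixes out of plain strings.
    use std::time::{SystemTime, UNIX_EPOCH};

    #[derive(Debug, Clone)]
    struct TranscriptionResult {
        transcription: Option<String>,
        timestamp: u64,
        error: Option<String>,
    }

    fn wrap(result: anyhow::Result<String>) -> TranscriptionResult {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs();
        match result {
            Ok(text) => TranscriptionResult {
                transcription: Some(text),
                timestamp,
                error: None,
            },
            Err(e) => TranscriptionResult {
                transcription: None,
                timestamp,
                error: Some(e.to_string()),
            },
        }
    }

    fn main() {
        println!("{:?}", wrap(Ok("hello".to_string())));
        println!("{:?}", wrap(Err(anyhow::anyhow!("decode failed"))));
    }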
diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs
index 167eca47..d3bfcd7a 100644
--- a/screenpipe-server/src/core.rs
+++ b/screenpipe-server/src/core.rs
@@ -4,7 +4,8 @@ use chrono::Utc;
 use crossbeam::channel::{Receiver, Sender};
 use log::{debug, error, info};
 use screenpipe_audio::{
-    create_whisper_channel, record_and_transcribe, AudioCaptureResult, DeviceSpec,
+    create_whisper_channel, record_and_transcribe, AudioCaptureResult, AudioInput, DeviceSpec,
+    TranscriptionResult,
 };
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -129,8 +130,8 @@ async fn record_audio(
     chunk_duration: Duration,
     is_running: Arc<AtomicBool>,
     devices: Vec<Arc<DeviceSpec>>,
-    whisper_sender: Sender<String>,
-    whisper_receiver: Receiver<String>,
+    whisper_sender: Sender<AudioInput>,
+    whisper_receiver: Receiver<TranscriptionResult>,
 ) -> Result<()> {
     let mut handles = vec![];
@@ -180,16 +181,7 @@ async fn record_audio(
 
                     // Process the recorded chunk
                     if let Ok(transcription) = whisper_receiver.recv() {
-                        let result = AudioCaptureResult {
-                            text: transcription,
-                        };
-                        process_audio_result(
-                            &db_clone,
-                            &file_path.to_str().unwrap(),
-                            &device_spec_clone,
-                            result,
-                        )
-                        .await;
+                        process_audio_result(&db_clone, transcription).await;
                     }
                 }
                 Ok(Err(e)) => error!("Error in record_and_transcribe: {}", e),
@@ -209,33 +201,36 @@ async fn record_audio(
     Ok(())
 }
 
-async fn process_audio_result(
-    db: &DatabaseManager,
-    output_path: &str,
-    device_spec: &DeviceSpec,
-    result: AudioCaptureResult,
-) {
-    info!("Inserting audio chunk: {:?}", result.text);
-    match db.insert_audio_chunk(&output_path).await {
+async fn process_audio_result(db: &DatabaseManager, result: TranscriptionResult) {
+    info!("Inserting audio chunk: {:?}", result.transcription);
+    if result.error.is_some() || result.transcription.is_none() {
+        error!(
+            "Error in audio recording: {}",
+            result.error.unwrap_or_default()
+        );
+        return;
+    }
+    let transcription = result.transcription.unwrap();
+    match db.insert_audio_chunk(&result.input.path).await {
         Ok(audio_chunk_id) => {
             if let Err(e) = db
-                .insert_audio_transcription(audio_chunk_id, &result.text, 0) // TODO index is in the text atm
+                .insert_audio_transcription(audio_chunk_id, &transcription, 0) // TODO index is in the text atm
                 .await
             {
                 error!(
                     "Failed to insert audio transcription for device {}: {}",
-                    device_spec, e
+                    result.input.device, e
                 );
             } else {
                 debug!(
                     "Inserted audio transcription for chunk {} from device {}",
-                    audio_chunk_id, device_spec
+                    audio_chunk_id, result.input.device
                 );
             }
         }
         Err(e) => error!(
             "Failed to insert audio chunk for device {}: {}",
-            device_spec, e
+            result.input.device, e
         ),
     }
 }
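On the server side, `process_audio_result` now guards before unwrapping: the early return covers both an explicit error and a missing transcript, which is what makes the later `unwrap()` safe. A reduced sketch of that control flow, where `store` is a stand-in for the `insert_audio_chunk`/`insert_audio_transcription` calls:

    // Guard-then-unwrap sketch: after the early return, transcription is
    // guaranteed to be Some, so unwrap cannot panic.
    #[derive(Debug)]
    struct TranscriptionResult {
        transcription: Option<String>,
        error: Option<String>,
    }

    fn store(text: &str) {
        println!("inserting transcription: {}", text);
    }

    fn process(result: TranscriptionResult) {
        if result.error.is_some() || result.transcription.is_none() {
            eprintln!(
                "Error in audio recording: {}",
                result.error.unwrap_or_default()
            );
            return;
        }
        // Safe: the guard above ruled out None.
        let transcription = result.transcription.unwrap();
        store(&transcription);
    }

    fn main() {
        process(TranscriptionResult {
            transcription: Some("hi".into()),
            error: None,
        });
        process(TranscriptionResult {
            transcription: None,
            error: Some("stt failed".into()),
        });
    }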