chore: properly propagate metadata across the channel
louis030195 committed Jul 10, 2024
1 parent c7d644e commit f9b5fe6
Showing 5 changed files with 126 additions and 76 deletions.
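In short: the whisper channel now carries structured types instead of bare strings. Producers send an AudioInput (the chunk's file path plus the device it came from), and consumers receive a TranscriptionResult that carries that input back along with the transcription, a timestamp, and any error. A minimal sketch of the new shape, reconstructed from the signatures in this diff (the wiring function and the device name are illustrative, not part of the commit):

    use anyhow::Result;
    use screenpipe_audio::{create_whisper_channel, AudioInput, TranscriptionResult};

    fn sketch() -> Result<()> {
        // Sender<AudioInput> in, Receiver<TranscriptionResult> out.
        let (whisper_sender, whisper_receiver) = create_whisper_channel()?;

        // Producers attach the originating device to every chunk they send.
        whisper_sender.send(AudioInput {
            path: "output_0.wav".to_string(),
            device: "MacBook Pro Microphone (input)".to_string(), // hypothetical name
        })?;

        // The result carries the original input back, so per-device metadata
        // survives the trip across the channel.
        let result: TranscriptionResult = whisper_receiver.recv()?;
        println!("{} -> {:?}", result.input.device, result.transcription);
        Ok(())
    }

In practice the send and recv happen on different threads, as the per-file diffs below show.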
74 changes: 44 additions & 30 deletions screenpipe-audio/src/bin/screenpipe-audio.rs
@@ -1,6 +1,9 @@
 use anyhow::{anyhow, Result};
 use clap::Parser;
 use log::info;
+use screenpipe_audio::create_whisper_channel;
+use screenpipe_audio::default_input_device;
+use screenpipe_audio::default_output_device;
 use screenpipe_audio::list_audio_devices;
 use screenpipe_audio::parse_device_spec;
 use screenpipe_audio::record_and_transcribe;
@@ -12,8 +15,12 @@ use std::time::Duration;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
-    #[clap(short, long, help = "Audio device name")]
-    audio_device: Option<String>,
+    #[clap(
+        short,
+        long,
+        help = "Audio device name (can be specified multiple times)"
+    )]
+    audio_device: Vec<String>,
 
     #[clap(long, help = "List available audio devices")]
     list_audio_devices: bool,
@@ -45,51 +52,58 @@ fn main() -> Result<()> {
         return Ok(());
     }
 
-    let device = match args.audio_device {
-        Some(d) => parse_device_spec(&d).unwrap(),
-        None => {
-            if devices.is_empty() {
-                return Err(anyhow!("No audio input devices found"));
-            }
-            eprintln!("No audio device specified. Available devices are:");
-            print_devices(&devices);
-            eprintln!("\nPlease specify one or more devices with:");
-            eprintln!(
-                " {} --audio-device \"Device Name (input)\" [--audio-device \"Another Device (output)\"]",
-                std::env::args().next().unwrap()
-            );
-            return Err(anyhow!("No device specified"));
-        }
+    let devices = if args.audio_device.is_empty() {
+        vec![default_input_device()?, default_output_device()?]
+    } else {
+        args.audio_device
+            .iter()
+            .map(|d| parse_device_spec(d))
+            .collect::<Result<Vec<_>>>()?
     };
 
-    let (result_tx, result_rx) = mpsc::channel();
+    if devices.is_empty() {
+        return Err(anyhow!("No audio input devices found"));
+    }
 
     let chunk_duration = Duration::from_secs(30);
     let output_path = PathBuf::from("output.wav");
-    // Spawn a thread to handle the recording and transcription
-    let recording_thread = thread::spawn(move || {
-        record_and_transcribe(&device, chunk_duration, result_tx, output_path)
-    });
+    let (whisper_sender, whisper_receiver) = create_whisper_channel()?;
+
+    // Spawn threads for each device
+    let recording_threads: Vec<_> = devices
+        .into_iter()
+        .enumerate()
+        .map(|(i, device)| {
+            let whisper_sender = whisper_sender.clone();
+            let output_path = output_path.with_file_name(format!("output_{}.wav", i));
+            thread::spawn(move || {
+                record_and_transcribe(&device, chunk_duration, output_path, whisper_sender)
+            })
+        })
+        .collect();
 
     // Main loop to receive and print transcriptions
     loop {
-        match result_rx.recv_timeout(Duration::from_secs(5)) {
+        match whisper_receiver.recv_timeout(Duration::from_secs(5)) {
             Ok(result) => {
-                info!("Transcription: {}", result.text);
+                info!("Transcription: {:?}", result);
             }
-            Err(mpsc::RecvTimeoutError::Timeout) => {
+            Err(crossbeam::channel::RecvTimeoutError::Timeout) => {
                 // No transcription received in 5 seconds, continue waiting
                 continue;
             }
-            Err(mpsc::RecvTimeoutError::Disconnected) => {
-                // Sender has been dropped, recording is complete
+            Err(crossbeam::channel::RecvTimeoutError::Disconnected) => {
+                // All senders have been dropped, recording is complete
                 break;
             }
         }
     }
 
-    // Wait for the recording thread to finish
-    let file_path = recording_thread.join().unwrap()?;
-    println!("Recording complete: {:?}", file_path);
+    // Wait for all recording threads to finish
+    for (i, thread) in recording_threads.into_iter().enumerate() {
+        let file_path = thread.join().unwrap()?;
+        println!("Recording {} complete: {:?}", i, file_path);
+    }
 
     Ok(())
 }
14 changes: 11 additions & 3 deletions screenpipe-audio/src/core.rs
@@ -12,6 +12,8 @@ use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 use std::{io::BufWriter, thread};
 
+use crate::AudioInput;
+
 pub struct AudioCaptureResult {
     pub text: String,
 }
@@ -82,7 +84,7 @@ pub fn record_and_transcribe(
     device_spec: &DeviceSpec,
     duration: Duration,
     output_path: PathBuf,
-    whisper_sender: Sender<String>,
+    whisper_sender: Sender<AudioInput>,
 ) -> Result<PathBuf> {
     let host = match device_spec {
         #[cfg(target_os = "macos")]
@@ -172,7 +174,10 @@ pub fn record_and_transcribe(
             writer.flush()?;
 
             // Send the file path to the whisper channel
-            whisper_sender.send(output_path.to_str().unwrap().to_string())?;
+            whisper_sender.send(AudioInput {
+                path: output_path.to_str().unwrap().to_string(),
+                device: device_spec.to_string(),
+            })?;
         }
     }
 
@@ -191,7 +196,10 @@ pub fn record_and_transcribe(
             writer.flush()?;
 
             // Send the file path to the whisper channel
-            whisper_sender.send(output_path.to_str().unwrap().to_string())?;
+            whisper_sender.send(AudioInput {
+                path: output_path.to_str().unwrap().to_string(),
+                device: device_spec.to_string(),
+            })?;
         }
     }
 
2 changes: 1 addition & 1 deletion screenpipe-audio/src/lib.rs
@@ -6,4 +6,4 @@ pub use core::{
     default_input_device, default_output_device, list_audio_devices, parse_device_spec,
     record_and_transcribe, AudioCaptureResult, AudioDevice, DeviceSpec,
 };
-pub use stt::create_whisper_channel;
+pub use stt::{create_whisper_channel, AudioInput, TranscriptionResult};
67 changes: 50 additions & 17 deletions screenpipe-audio/src/stt.rs
@@ -1,4 +1,7 @@
-use std::thread;
+use std::{
+    thread,
+    time::{SystemTime, UNIX_EPOCH},
+};
 
 use anyhow::{Error as E, Result};
 use candle::{Device, IndexOp, Tensor};
@@ -393,7 +396,7 @@ enum Task {
     Translate,
 }
 
-pub fn stt(input: &str, whisper_model: &WhisperModel) -> Result<String> {
+pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
     info!("Starting speech to text");
     let mut model = &whisper_model.model;
     let tokenizer = &whisper_model.tokenizer;
@@ -407,7 +410,7 @@ pub fn stt(file_path: &str, whisper_model: &WhisperModel) -> Result<String> {
     let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
     <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
 
-    let (mut pcm_data, sample_rate) = pcm_decode(input)?;
+    let (mut pcm_data, sample_rate) = pcm_decode(file_path)?;
     if sample_rate != m::SAMPLE_RATE as u32 {
         info!(
             "Resampling from {} Hz to {} Hz",
@@ -479,24 +482,54 @@ fn resample(input: Vec<f32>, from_sample_rate: u32, to_sample_rate: u32) -> Resu
     Ok(waves_out.into_iter().next().unwrap())
 }
 
-pub fn create_whisper_channel() -> Result<(Sender<String>, Receiver<String>)> {
+#[derive(Debug, Clone)]
+pub struct AudioInput {
+    pub path: String,
+    pub device: String,
+}
+
+#[derive(Debug, Clone)]
+pub struct TranscriptionResult {
+    pub input: AudioInput,
+    pub transcription: Option<String>,
+    pub timestamp: u64,
+    pub error: Option<String>,
+}
+pub fn create_whisper_channel() -> Result<(Sender<AudioInput>, Receiver<TranscriptionResult>)> {
     let whisper_model = WhisperModel::new()?;
-    let (input_sender, input_receiver): (Sender<String>, Receiver<String>) = channel::unbounded();
-    let (output_sender, output_receiver): (Sender<String>, Receiver<String>) = channel::unbounded();
+    let (input_sender, input_receiver): (Sender<AudioInput>, Receiver<AudioInput>) =
+        channel::unbounded();
+    let (output_sender, output_receiver): (
+        Sender<TranscriptionResult>,
+        Receiver<TranscriptionResult>,
+    ) = channel::unbounded();
 
     thread::spawn(move || {
         while let Ok(input) = input_receiver.recv() {
-            match stt(&input, &whisper_model) {
-                Ok(result) => {
-                    if output_sender.send(result).is_err() {
-                        break;
-                    }
-                }
-                Err(e) => {
-                    if output_sender.send(format!("Error: {}", e)).is_err() {
-                        break;
-                    }
-                }
-            }
+            let timestamp = SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_secs();
+
+            let result = stt(&input.path, &whisper_model);
+
+            let transcription_result = match result {
+                Ok(transcription) => TranscriptionResult {
+                    input: input.clone(),
+                    transcription: Some(transcription),
+                    timestamp,
+                    error: None,
+                },
+                Err(e) => TranscriptionResult {
+                    input: input.clone(),
+                    transcription: None,
+                    timestamp,
+                    error: Some(e.to_string()),
+                },
+            };
+
+            if output_sender.send(transcription_result).is_err() {
+                break;
+            }
         }
     });
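Errors now travel through the channel as data rather than as formatted strings, so a consumer can branch on the propagated metadata. A hedged sketch of handling a TranscriptionResult (the handler name is hypothetical; the real consumer is process_audio_result in screenpipe-server below):

    use screenpipe_audio::TranscriptionResult;

    fn handle(result: TranscriptionResult) {
        // Moving transcription/error out of the struct still leaves
        // result.input and result.timestamp accessible (partial move).
        match (result.transcription, result.error) {
            (Some(text), None) => println!(
                "[{}] {} (at {}s since epoch)",
                result.input.device, text, result.timestamp
            ),
            (_, err) => eprintln!(
                "transcription failed for {}: {}",
                result.input.path,
                err.unwrap_or_default()
            ),
        }
    }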
45 changes: 20 additions & 25 deletions screenpipe-server/src/core.rs
@@ -4,7 +4,8 @@ use chrono::Utc;
 use crossbeam::channel::{Receiver, Sender};
 use log::{debug, error, info};
 use screenpipe_audio::{
-    create_whisper_channel, record_and_transcribe, AudioCaptureResult, DeviceSpec,
+    create_whisper_channel, record_and_transcribe, AudioCaptureResult, AudioInput, DeviceSpec,
+    TranscriptionResult,
 };
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -129,8 +130,8 @@ async fn record_audio(
     chunk_duration: Duration,
     is_running: Arc<AtomicBool>,
     devices: Vec<Arc<DeviceSpec>>,
-    whisper_sender: Sender<String>,
-    whisper_receiver: Receiver<String>,
+    whisper_sender: Sender<AudioInput>,
+    whisper_receiver: Receiver<TranscriptionResult>,
 ) -> Result<()> {
     let mut handles = vec![];
 
@@ -180,16 +181,7 @@
 
                     // Process the recorded chunk
                     if let Ok(transcription) = whisper_receiver.recv() {
-                        let result = AudioCaptureResult {
-                            text: transcription,
-                        };
-                        process_audio_result(
-                            &db_clone,
-                            &file_path.to_str().unwrap(),
-                            &device_spec_clone,
-                            result,
-                        )
-                        .await;
+                        process_audio_result(&db_clone, transcription).await;
                     }
                 }
                 Ok(Err(e)) => error!("Error in record_and_transcribe: {}", e),
Expand All @@ -209,33 +201,36 @@ async fn record_audio(

Ok(())
}
async fn process_audio_result(
db: &DatabaseManager,
output_path: &str,
device_spec: &DeviceSpec,
result: AudioCaptureResult,
) {
info!("Inserting audio chunk: {:?}", result.text);
match db.insert_audio_chunk(&output_path).await {
async fn process_audio_result(db: &DatabaseManager, result: TranscriptionResult) {
info!("Inserting audio chunk: {:?}", result.transcription);
if result.error.is_some() || result.transcription.is_none() {
error!(
"Error in audio recording: {}",
result.error.unwrap_or_default()
);
return;
}
let transcription = result.transcription.unwrap();
match db.insert_audio_chunk(&result.input.path).await {
Ok(audio_chunk_id) => {
if let Err(e) = db
.insert_audio_transcription(audio_chunk_id, &result.text, 0) // TODO index is in the text atm
.insert_audio_transcription(audio_chunk_id, &transcription, 0) // TODO index is in the text atm
.await
{
error!(
"Failed to insert audio transcription for device {}: {}",
device_spec, e
result.input.device, e
);
} else {
debug!(
"Inserted audio transcription for chunk {} from device {}",
audio_chunk_id, device_spec
audio_chunk_id, result.input.device
);
}
}
Err(e) => error!(
"Failed to insert audio chunk for device {}: {}",
device_spec, e
result.input.device, e
),
}
}
