Skip to content

Commit

Permalink
Merge pull request #31 from louis030195/fix-#29
Browse files: browse the repository at this point in the history
#29 fix
  • Loading branch information
louis030195 authored Jul 10, 2024
2 parents 15018f4 + f9b5fe6 commit 30e3813
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 142 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ candle = { package = "candle-core", version = "0.6.0" }
candle-nn = { package = "candle-nn", version = "0.6.0" }
candle-transformers = { package = "candle-transformers", version = "0.6.0" }
tokenizers = "0.19.1"

tracing = "0.1.37"
6 changes: 6 additions & 0 deletions screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ env_logger = "0.10"
# File
tempfile = "3"

# Tracing
tracing = { workspace = true }

# Concurrency
crossbeam = "0.8"

[dev-dependencies]
tempfile = "3.3.0"

Expand Down
74 changes: 44 additions & 30 deletions screenpipe-audio/src/bin/screenpipe-audio.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use anyhow::{anyhow, Result};
use clap::Parser;
use log::info;
use screenpipe_audio::create_whisper_channel;
use screenpipe_audio::default_input_device;
use screenpipe_audio::default_output_device;
use screenpipe_audio::list_audio_devices;
use screenpipe_audio::parse_device_spec;
use screenpipe_audio::record_and_transcribe;
Expand All @@ -12,8 +15,12 @@ use std::time::Duration;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(short, long, help = "Audio device name")]
audio_device: Option<String>,
#[clap(
short,
long,
help = "Audio device name (can be specified multiple times)"
)]
audio_device: Vec<String>,

#[clap(long, help = "List available audio devices")]
list_audio_devices: bool,
Expand Down Expand Up @@ -45,51 +52,58 @@ fn main() -> Result<()> {
return Ok(());
}

let device = match args.audio_device {
Some(d) => parse_device_spec(&d).unwrap(),
None => {
if devices.is_empty() {
return Err(anyhow!("No audio input devices found"));
}
eprintln!("No audio device specified. Available devices are:");
print_devices(&devices);
eprintln!("\nPlease specify one or more devices with:");
eprintln!(
" {} --audio-device \"Device Name (input)\" [--audio-device \"Another Device (output)\"]",
std::env::args().next().unwrap()
);
return Err(anyhow!("No device specified"));
}
let devices = if args.audio_device.is_empty() {
vec![default_input_device()?, default_output_device()?]
} else {
args.audio_device
.iter()
.map(|d| parse_device_spec(d))
.collect::<Result<Vec<_>>>()?
};

let (result_tx, result_rx) = mpsc::channel();
if devices.is_empty() {
return Err(anyhow!("No audio input devices found"));
}

let chunk_duration = Duration::from_secs(30);
let output_path = PathBuf::from("output.wav");
// Spawn a thread to handle the recording and transcription
let recording_thread = thread::spawn(move || {
record_and_transcribe(&device, chunk_duration, result_tx, output_path)
});
let (whisper_sender, whisper_receiver) = create_whisper_channel()?;

// Spawn threads for each device
let recording_threads: Vec<_> = devices
.into_iter()
.enumerate()
.map(|(i, device)| {
let whisper_sender = whisper_sender.clone();
let output_path = output_path.with_file_name(format!("output_{}.wav", i));
thread::spawn(move || {
record_and_transcribe(&device, chunk_duration, output_path, whisper_sender)
})
})
.collect();

// Main loop to receive and print transcriptions
loop {
match result_rx.recv_timeout(Duration::from_secs(5)) {
match whisper_receiver.recv_timeout(Duration::from_secs(5)) {
Ok(result) => {
info!("Transcription: {}", result.text);
info!("Transcription: {:?}", result);
}
Err(mpsc::RecvTimeoutError::Timeout) => {
Err(crossbeam::channel::RecvTimeoutError::Timeout) => {
// No transcription received in 5 seconds, continue waiting
continue;
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
// Sender has been dropped, recording is complete
Err(crossbeam::channel::RecvTimeoutError::Disconnected) => {
// All senders have been dropped, recording is complete
break;
}
}
}

// Wait for the recording thread to finish
let file_path = recording_thread.join().unwrap()?;
println!("Recording complete: {:?}", file_path);
// Wait for all recording threads to finish
for (i, thread) in recording_threads.into_iter().enumerate() {
let file_path = thread.join().unwrap()?;
println!("Recording {} complete: {:?}", i, file_path);
}

Ok(())
}
53 changes: 25 additions & 28 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use anyhow::{anyhow, Result};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{FromSample, Sample};
use crossbeam::channel::Sender;
use hound::WavSpec;
use log::{error, info};
use serde::Serialize;
use std::fmt;
use std::fs::File;
use std::path::PathBuf;
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{io::BufWriter, thread};

use crate::stt::stt;
use crate::AudioInput;

pub struct AudioCaptureResult {
pub text: String,
Expand Down Expand Up @@ -83,8 +83,8 @@ pub fn parse_device_spec(name: &str) -> Result<DeviceSpec> {
pub fn record_and_transcribe(
device_spec: &DeviceSpec,
duration: Duration,
result_tx: Sender<AudioCaptureResult>,
output_path: PathBuf,
whisper_sender: Sender<AudioInput>,
) -> Result<PathBuf> {
let host = match device_spec {
#[cfg(target_os = "macos")]
Expand All @@ -94,7 +94,7 @@ pub fn record_and_transcribe(

info!("device: {:?}", device_spec.to_string());

let device = if device_spec.to_string() == "default" {
let audio_device = if device_spec.to_string() == "default" {
host.default_input_device()
} else {
host.input_devices()?.find(|x| {
Expand All @@ -108,10 +108,14 @@ pub fn record_and_transcribe(
.unwrap_or(false)
})
}
.ok_or_else(|| anyhow!("Device not found"))?;
.ok_or_else(|| anyhow!("Audio device not found"))?;

let config = device.default_input_config()?;
info!("Recording device: {}, Config: {:?}", device.name()?, config);
let config = audio_device.default_input_config()?;
info!(
"Recording audio device: {}, Config: {:?}",
audio_device.name()?,
config
);

let spec = wav_spec_from_config(&config);
let writer = hound::WavWriter::create(&output_path, spec)?;
Expand All @@ -120,25 +124,25 @@ pub fn record_and_transcribe(
let err_fn = |err| error!("An error occurred on the audio stream: {}", err);

let stream = match config.sample_format() {
cpal::SampleFormat::I8 => device.build_input_stream(
cpal::SampleFormat::I8 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i8, i8>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::I16 => device.build_input_stream(
cpal::SampleFormat::I16 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i16, i16>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::I32 => device.build_input_stream(
cpal::SampleFormat::I32 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i32, i32>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::F32 => device.build_input_stream(
cpal::SampleFormat::F32 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<f32, f32>(data, &writer_2),
err_fn,
Expand Down Expand Up @@ -169,15 +173,11 @@ pub fn record_and_transcribe(
if let Some(writer) = writer_guard.as_mut() {
writer.flush()?;

// Transcribe the current audio chunk
match stt(output_path.to_str().unwrap()) {
Ok(transcription) => {
result_tx.send(AudioCaptureResult {
text: transcription,
})?;
}
Err(e) => error!("Transcription failed: {}", e),
}
// Send the file path to the whisper channel
whisper_sender.send(AudioInput {
path: output_path.to_str().unwrap().to_string(),
device: device_spec.to_string(),
})?;
}
}

Expand All @@ -195,14 +195,11 @@ pub fn record_and_transcribe(
if let Some(writer) = writer_guard.as_mut() {
writer.flush()?;

match stt(output_path.to_str().unwrap()) {
Ok(transcription) => {
result_tx.send(AudioCaptureResult {
text: transcription,
})?;
}
Err(e) => error!("Final transcription failed: {}", e),
}
// Send the file path to the whisper channel
whisper_sender.send(AudioInput {
path: output_path.to_str().unwrap().to_string(),
device: device_spec.to_string(),
})?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ mod core;
mod multilingual;
mod pcm_decode;
mod stt;

pub use core::{
default_input_device, default_output_device, list_audio_devices, parse_device_spec,
record_and_transcribe, AudioCaptureResult, AudioDevice, DeviceSpec,
};
pub use stt::{create_whisper_channel, AudioInput, TranscriptionResult};
4 changes: 3 additions & 1 deletion screenpipe-audio/src/multilingual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use candle_transformers::models::whisper::SOT_TOKEN;
use log::info;
use tokenizers::Tokenizer;

use crate::stt::Model;

const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
Expand Down Expand Up @@ -107,7 +109,7 @@ const LANGUAGES: [(&str, &str); 99] = [

/// Returns the token id for the selected language.
pub fn detect_language(
model: &mut super::stt::Model,
model: &mut Model,
tokenizer: &Tokenizer,
mel: &Tensor,
) -> Result<u32> {
Expand Down
Loading

0 comments on commit 30e3813

Please sign in to comment.