Skip to content

Commit

Permalink
#29 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Jul 10, 2024
1 parent 15018f4 commit c7d644e
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 93 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ candle = { package = "candle-core", version = "0.6.0" }
candle-nn = { package = "candle-nn", version = "0.6.0" }
candle-transformers = { package = "candle-transformers", version = "0.6.0" }
tokenizers = "0.19.1"

tracing = "0.1.37"
6 changes: 6 additions & 0 deletions screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ env_logger = "0.10"
# File
tempfile = "3"

# Tracing
tracing = { workspace = true }

# Concurrency
crossbeam = "0.8"

[dev-dependencies]
tempfile = "3.3.0"

Expand Down
47 changes: 18 additions & 29 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
use anyhow::{anyhow, Result};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{FromSample, Sample};
use crossbeam::channel::Sender;
use hound::WavSpec;
use log::{error, info};
use serde::Serialize;
use std::fmt;
use std::fs::File;
use std::path::PathBuf;
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{io::BufWriter, thread};

use crate::stt::stt;

/// Result value produced by the audio capture pipeline.
///
/// NOTE(review): only the `text` field is visible in this view; presumably it
/// carries the transcription of a recorded audio chunk — confirm against the
/// code that constructs this struct.
pub struct AudioCaptureResult {
/// Text associated with the captured audio.
pub text: String,
}
Expand Down Expand Up @@ -83,8 +81,8 @@ pub fn parse_device_spec(name: &str) -> Result<DeviceSpec> {
pub fn record_and_transcribe(
device_spec: &DeviceSpec,
duration: Duration,
result_tx: Sender<AudioCaptureResult>,
output_path: PathBuf,
whisper_sender: Sender<String>,
) -> Result<PathBuf> {
let host = match device_spec {
#[cfg(target_os = "macos")]
Expand All @@ -94,7 +92,7 @@ pub fn record_and_transcribe(

info!("device: {:?}", device_spec.to_string());

let device = if device_spec.to_string() == "default" {
let audio_device = if device_spec.to_string() == "default" {
host.default_input_device()
} else {
host.input_devices()?.find(|x| {
Expand All @@ -108,10 +106,14 @@ pub fn record_and_transcribe(
.unwrap_or(false)
})
}
.ok_or_else(|| anyhow!("Device not found"))?;
.ok_or_else(|| anyhow!("Audio device not found"))?;

let config = device.default_input_config()?;
info!("Recording device: {}, Config: {:?}", device.name()?, config);
let config = audio_device.default_input_config()?;
info!(
"Recording audio device: {}, Config: {:?}",
audio_device.name()?,
config
);

let spec = wav_spec_from_config(&config);
let writer = hound::WavWriter::create(&output_path, spec)?;
Expand All @@ -120,25 +122,25 @@ pub fn record_and_transcribe(
let err_fn = |err| error!("An error occurred on the audio stream: {}", err);

let stream = match config.sample_format() {
cpal::SampleFormat::I8 => device.build_input_stream(
cpal::SampleFormat::I8 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i8, i8>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::I16 => device.build_input_stream(
cpal::SampleFormat::I16 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i16, i16>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::I32 => device.build_input_stream(
cpal::SampleFormat::I32 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<i32, i32>(data, &writer_2),
err_fn,
None,
)?,
cpal::SampleFormat::F32 => device.build_input_stream(
cpal::SampleFormat::F32 => audio_device.build_input_stream(
&config.into(),
move |data, _: &_| write_input_data::<f32, f32>(data, &writer_2),
err_fn,
Expand Down Expand Up @@ -169,15 +171,8 @@ pub fn record_and_transcribe(
if let Some(writer) = writer_guard.as_mut() {
writer.flush()?;

// Transcribe the current audio chunk
match stt(output_path.to_str().unwrap()) {
Ok(transcription) => {
result_tx.send(AudioCaptureResult {
text: transcription,
})?;
}
Err(e) => error!("Transcription failed: {}", e),
}
// Send the file path to the whisper channel
whisper_sender.send(output_path.to_str().unwrap().to_string())?;
}
}

Expand All @@ -195,14 +190,8 @@ pub fn record_and_transcribe(
if let Some(writer) = writer_guard.as_mut() {
writer.flush()?;

match stt(output_path.to_str().unwrap()) {
Ok(transcription) => {
result_tx.send(AudioCaptureResult {
text: transcription,
})?;
}
Err(e) => error!("Final transcription failed: {}", e),
}
// Send the file path to the whisper channel
whisper_sender.send(output_path.to_str().unwrap().to_string())?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ mod core;
mod multilingual;
mod pcm_decode;
mod stt;

pub use core::{
default_input_device, default_output_device, list_audio_devices, parse_device_spec,
record_and_transcribe, AudioCaptureResult, AudioDevice, DeviceSpec,
};
pub use stt::create_whisper_channel;
4 changes: 3 additions & 1 deletion screenpipe-audio/src/multilingual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use candle_transformers::models::whisper::SOT_TOKEN;
use log::info;
use tokenizers::Tokenizer;

use crate::stt::Model;

const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
Expand Down Expand Up @@ -107,7 +109,7 @@ const LANGUAGES: [(&str, &str); 99] = [

/// Returns the token id for the selected language.
pub fn detect_language(
model: &mut super::stt::Model,
model: &mut Model,
tokenizer: &Tokenizer,
mel: &Tensor,
) -> Result<u32> {
Expand Down
Loading

0 comments on commit c7d644e

Please sign in to comment.