Skip to content

Commit

Permalink
Merge pull request #330 from kerosina/patch-fix-tokio-errors
Browse files Browse the repository at this point in the history
Remove all references to blocking reqwest, replacing with async reqwest
  • Loading branch information
louis030195 committed Sep 18, 2024
2 parents a9df45b + 5f593bd commit 68e6c94
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 42 deletions.
57 changes: 44 additions & 13 deletions screenpipe-audio/src/stt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
path::PathBuf,
sync::Arc,
sync::{Arc, Mutex},
time::{SystemTime, UNIX_EPOCH},
};

Expand Down Expand Up @@ -425,16 +425,15 @@ enum Task {
Translate,
}

use reqwest::blocking::Client;
use reqwest::Client;
use serde_json::Value;

/// Returns the Deepgram API key that is compiled into this module.
///
/// NOTE(review): this is a credential hard-coded in source. Presumably it is an
/// intentionally shared default key — confirm with the maintainers, and consider
/// loading it from configuration and rotating this value.
fn get_deepgram_api_key() -> String {
    const BUILT_IN_KEY: &str = "7ed2a159a094337b01fd8178b914b7ae0e77822d";
    String::from(BUILT_IN_KEY)
}

// TODO: use async reqwest instead of the blocking client; blocking calls cause crashes because the rest of our code is async
fn transcribe_with_deepgram(
async fn transcribe_with_deepgram(
api_key: &str,
audio_data: &[f32],
device: &str,
Expand Down Expand Up @@ -469,10 +468,10 @@ fn transcribe_with_deepgram(
.body(wav_data)
.send();

match response {
match response.await {
Ok(resp) => {
debug!("received response from deepgram api");
match resp.json::<Value>() {
match resp.json::<Value>().await {
Ok(result) => {
debug!("successfully parsed json response");
if let Some(err_code) = result.get("err_code") {
Expand Down Expand Up @@ -518,7 +517,37 @@ fn transcribe_with_deepgram(
}
}

pub fn stt(
/// Blocking wrapper around the async [`stt`] function.
///
/// Takes owned copies of the borrowed arguments, runs `stt` to completion on a
/// freshly spawned OS thread that creates its own Tokio runtime, and blocks the
/// calling thread until that worker thread has finished.
///
/// The VAD engine mutex is locked on the worker thread before `stt` starts and
/// the guard is held until `stt` returns, because `stt` needs a `&mut` to the
/// engine for the whole call.
///
/// # Returns
/// Whatever `stt` returns for the given input.
///
/// # Panics
/// Panics if the Tokio runtime cannot be created or if the VAD engine mutex is
/// poisoned. A panic raised on the worker thread is re-raised on the calling
/// thread with its original payload.
pub fn stt_sync(
    audio_input: &AudioInput,
    whisper_model: &WhisperModel,
    audio_transcription_engine: Arc<AudioTranscriptionEngine>,
    vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>>, // shared so each call can lock it from its worker thread
    deepgram_api_key: Option<String>,
    output_path: &PathBuf,
) -> Result<(String, String)> {
    // `std::thread::spawn` requires a `'static` closure, so the borrowed
    // arguments are cloned into owned values that can be moved into it.
    // `vad_engine` is already an owned `Arc` and is moved in directly.
    let audio_input = audio_input.clone();
    let whisper_model = whisper_model.clone();
    let output_path = output_path.clone();

    let worker = std::thread::spawn(move || {
        let rt = tokio::runtime::Runtime::new()
            .expect("failed to create a tokio runtime for stt_sync");
        let mut vad_engine_guard = vad_engine
            .lock()
            .expect("vad engine mutex was poisoned by an earlier panic");

        rt.block_on(stt(
            &audio_input,
            &whisper_model,
            audio_transcription_engine,
            &mut **vad_engine_guard, // MutexGuard -> Box -> &mut dyn VadEngine
            deepgram_api_key,
            &output_path,
        ))
    });

    match worker.join() {
        Ok(result) => result,
        // Propagate the worker's own panic payload instead of replacing it
        // with an opaque "called `Result::unwrap()` on an `Err` value: Any".
        Err(payload) => std::panic::resume_unwind(payload),
    }
}

pub async fn stt(
audio_input: &AudioInput,
whisper_model: &WhisperModel,
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
Expand Down Expand Up @@ -648,7 +677,9 @@ pub fn stt(
&speech_frames,
&audio_input.device.name,
audio_input.sample_rate,
) {
)
.await
{
Ok(transcription) => Ok(transcription),
Err(e) => {
error!(
Expand Down Expand Up @@ -838,11 +869,11 @@ pub async fn create_whisper_channel(
UnboundedSender<TranscriptionResult>,
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();
let mut vad_engine: Box<dyn VadEngine + Send> = match vad_engine {
let vad_engine: Box<dyn VadEngine + Send> = match vad_engine {
VadEngineEnum::WebRtc => Box::new(WebRtcVad::new()),
VadEngineEnum::Silero => Box::new(SileroVad::new()?),
VadEngineEnum::Silero => Box::new(SileroVad::new().await?),
};

let vad_engine = Arc::new(Mutex::new(vad_engine));
let shutdown_flag = Arc::new(AtomicBool::new(false));
let shutdown_flag_clone = shutdown_flag.clone();
let output_path = output_path.clone();
Expand All @@ -867,7 +898,7 @@ pub async fn create_whisper_channel(
#[cfg(target_os = "macos")]
{
autoreleasepool(|| {
match stt(&input, &whisper_model, audio_transcription_engine.clone(), &mut *vad_engine, deepgram_api_key.clone(), &output_path) {
match stt_sync(&input, &whisper_model, audio_transcription_engine.clone(), vad_engine.clone(), deepgram_api_key.clone(), &output_path) {
Ok((transcription, path)) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
Expand All @@ -893,7 +924,7 @@ pub async fn create_whisper_channel(
unreachable!("This code should not be reached on non-macOS platforms")
}
} else {
match stt(&input, &whisper_model, audio_transcription_engine.clone(), &mut *vad_engine, deepgram_api_key.clone(), &output_path) {
match stt_sync(&input, &whisper_model, audio_transcription_engine.clone(), vad_engine.clone(), deepgram_api_key.clone(), &output_path) {
Ok((transcription, path)) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
Expand Down
16 changes: 8 additions & 8 deletions screenpipe-audio/src/vad_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub enum VadEngineEnum {
Silero,
}

pub trait VadEngine {
pub trait VadEngine: Send {
fn is_voice_segment(&mut self, audio_chunk: &[f32]) -> anyhow::Result<bool>;
}

Expand Down Expand Up @@ -44,9 +44,9 @@ pub struct SileroVad {
}

impl SileroVad {
pub fn new() -> anyhow::Result<Self> {
pub async fn new() -> anyhow::Result<Self> {
debug!("Initializing SileroVad...");
let model_path = Self::download_model()?;
let model_path = Self::download_model().await?;
debug!("SileroVad Model downloaded to: {:?}", model_path);
let vad = Vad::new(model_path, 16000).map_err(|e| {
debug!("SileroVad Error creating Vad: {}", e);
Expand All @@ -56,12 +56,12 @@ impl SileroVad {
Ok(Self { vad })
}

fn download_model() -> anyhow::Result<PathBuf> {
async fn download_model() -> anyhow::Result<PathBuf> {
debug!("Downloading SileroVAD model...");
let url =
"https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx";
let response = reqwest::blocking::get(url)?;
let model_data = response.bytes()?;
let response = reqwest::get(url).await?;
let model_data = response.bytes().await?;

let path = std::env::temp_dir().join("silero_vad.onnx");
let mut file = std::fs::File::create(&path)?;
Expand Down Expand Up @@ -94,11 +94,11 @@ impl VadEngine for SileroVad {
}
}

pub fn create_vad_engine(engine: VadEngineEnum) -> anyhow::Result<Box<dyn VadEngine>> {
pub async fn create_vad_engine(engine: VadEngineEnum) -> anyhow::Result<Box<dyn VadEngine>> {
match engine {
VadEngineEnum::WebRtc => Ok(Box::new(WebRtcVad::new())),
VadEngineEnum::Silero => {
let silero_vad = SileroVad::new()?;
let silero_vad = SileroVad::new().await?;
Ok(Box::new(silero_vad))
}
}
Expand Down
3 changes: 2 additions & 1 deletion screenpipe-integrations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ chrono = { version = "0.4", features = ["serde"] }
log = "0.4"
tempfile = "3.2"
anyhow = "1.0"
chrono-tz = "0.8"
chrono-tz = "0.8"
mime_guess = "2.0.5"
78 changes: 58 additions & 20 deletions screenpipe-integrations/src/unstructured_ocr.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use image::{DynamicImage, ImageEncoder, codecs::png::PngEncoder};
use anyhow::{anyhow, Result};
use image::{codecs::png::PngEncoder, DynamicImage, ImageEncoder};
use log::error;
use reqwest::multipart::{Form, Part};
use reqwest::Client;
use serde_json;
use serde_json::Value;
use std::collections::HashMap;
use std::env;
use std::io::Cursor;
use std::io::Read;
use std::io::Write;
use tokio::time::{timeout, Duration};
use reqwest::blocking::Client;
use serde_json::Value;
use std::path::PathBuf;
use tempfile::NamedTempFile;
use anyhow::{Result, anyhow};
use log::error;
use tokio::time::{timeout, Duration};

pub async fn perform_ocr_cloud(image: &DynamicImage) -> Result<(String, String, Option<f64>)> {
let api_key = match env::var("UNSTRUCTURED_API_KEY") {
Expand Down Expand Up @@ -44,12 +46,17 @@ pub async fn perform_ocr_cloud(image: &DynamicImage) -> Result<(String, String,
.text("coordinates", "true");

let client = reqwest::Client::new();
let response = match timeout(Duration::from_secs(180), client
.post(&api_url)
.header("accept", "application/json")
.header("unstructured-api-key", &api_key)
.multipart(form)
.send()).await {
let response = match timeout(
Duration::from_secs(180),
client
.post(&api_url)
.header("accept", "application/json")
.header("unstructured-api-key", &api_key)
.multipart(form)
.send(),
)
.await
{
Ok(Ok(response)) => response,
Ok(Err(e)) => return Err(anyhow!("Request error: {}", e)),
Err(_) => return Err(anyhow!("Request timed out")),
Expand Down Expand Up @@ -89,7 +96,7 @@ fn calculate_overall_confidence(parsed_response: &Vec<HashMap<String, serde_json
}
}

pub fn unstructured_chunking(text: &str) -> Result<Vec<String>> {
pub async fn unstructured_chunking(text: &str) -> Result<Vec<String>> {
let client = Client::new();
let api_key = match env::var("UNSTRUCTURED_API_KEY") {
Ok(key) => key,
Expand All @@ -100,32 +107,63 @@ pub fn unstructured_chunking(text: &str) -> Result<Vec<String>> {
};
// Create temporary file
let mut temp_file = NamedTempFile::new().map_err(|e| anyhow!(e.to_string()))?;
temp_file.write_all(text.as_bytes()).map_err(|e| anyhow!(e.to_string()))?;
temp_file
.write_all(text.as_bytes())
.map_err(|e| anyhow!(e.to_string()))?;

// Prepare request
let form = reqwest::blocking::multipart::Form::new()
.file("files", temp_file.path()).map_err(|e| anyhow!(e.to_string()))?
let form = reqwest::multipart::Form::new()
.part("files", {
let mut bytes = vec![];
temp_file.read_to_end(&mut bytes)?;

let path = PathBuf::from(temp_file.path());

let file_name = path
.file_name()
.ok_or(anyhow!("Couldn't send files to unstructuredapp API"))?
.to_string_lossy()
.into_owned();

let mime_type = mime_guess::from_path(path)
.first()
.ok_or(anyhow!("Couldn't determine file's MIME type."))?
.essence_str()
.to_owned();

let part = Part::bytes(bytes)
.file_name(file_name)
.mime_str(&mime_type)?;

part
})
.text("chunking_strategy", "by_similarity")
.text("similarity_threshold", "0.5")
.text("max_characters", "300")
.text("output_format", "application/json");

// Send request
let response = client.post("https://api.unstructuredapp.io/general/v0/general")
let response = client
.post("https://api.unstructuredapp.io/general/v0/general")
.header("accept", "application/json")
.header("unstructured-api-key", &api_key)
.multipart(form)
.send()
.await
.map_err(|e| anyhow!(e.to_string()))?;

if response.status().is_success() {
let chunks: Vec<Value> = response.json().map_err(|e| anyhow!(e.to_string()))?;
let texts: Vec<String> = chunks.iter()
let chunks = response
.json::<Vec<Value>>()
.await
.map_err(|e| anyhow!(e.to_string()))?;
let texts: Vec<String> = chunks
.iter()
.filter_map(|chunk| chunk["text"].as_str().map(String::from))
.collect();

Ok(texts)
} else {
Err(anyhow!("Error: {}", response.status()))
}
}
}

0 comments on commit 68e6c94

Please sign in to comment.