Skip to content

Commit

Permalink
chore: memory optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Sep 20, 2024
1 parent ddf4023 commit 7f54487
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ resolver = "2"


[workspace.package]
version = "0.1.84"
version = "0.1.85"
authors = ["louis030195 <[email protected]>"]
description = ""
repository = "https://github.com/mediar-ai/screenpipe"
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-app-tauri/src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "screenpipe-app"
version = "0.2.56"
version = "0.2.57"
description = ""
authors = ["you"]
license = ""
Expand Down
3 changes: 3 additions & 0 deletions screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ reqwest = { workspace = true }

screenpipe-core = { path = "../screenpipe-core" }

# crossbeam
crossbeam = { workspace = true }

[target.'cfg(target_os = "windows")'.dependencies]
ort = { version = "2.0.0-rc.5", features = ["download-binaries", "copy-dylibs", "directml", "cuda"] }
esaxx-rs = "0.1.10"
Expand Down
10 changes: 5 additions & 5 deletions screenpipe-audio/src/bin/screenpipe-audio-forever.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async fn main() -> Result<()> {
}

let chunk_duration = Duration::from_secs_f32(args.audio_chunk_duration);
let (whisper_sender, mut whisper_receiver, _) = create_whisper_channel(
let (whisper_sender, whisper_receiver, _) = create_whisper_channel(
Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3),
VadEngineEnum::Silero, // Or VadEngineEnum::WebRtc, hardcoded for now
args.deepgram_api_key,
Expand Down Expand Up @@ -121,12 +121,12 @@ async fn main() -> Result<()> {

// Main loop to receive and print transcriptions
loop {
match whisper_receiver.recv().await {
Some(result) => {
match whisper_receiver.recv() {
Ok(result) => {
info!("Transcription: {:?}", result);
}
None => {
eprintln!("Error receiving transcription");
Err(e) => {
eprintln!("Error receiving transcription: {:?}", e);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/src/bin/screenpipe-audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async fn main() -> Result<()> {

let chunk_duration = Duration::from_secs(10);
let output_path = PathBuf::from("output.mp4");
let (whisper_sender, mut whisper_receiver, _) = create_whisper_channel(
let (whisper_sender, whisper_receiver, _) = create_whisper_channel(
Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3),
VadEngineEnum::WebRtc, // Or VadEngineEnum::WebRtc, hardcoded for now
deepgram_api_key,
Expand Down
3 changes: 1 addition & 2 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, thread};
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;

#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -158,7 +157,7 @@ async fn get_device_and_config(
pub async fn record_and_transcribe(
audio_device: Arc<AudioDevice>,
duration: Duration,
whisper_sender: UnboundedSender<AudioInput>,
whisper_sender: crossbeam::channel::Sender<AudioInput>,
is_running: Arc<AtomicBool>,
) -> Result<()> {
let (cpal_audio_device, config) = get_device_and_config(&audio_device).await?;
Expand Down
114 changes: 61 additions & 53 deletions screenpipe-audio/src/stt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use log::{debug, error, info};
use objc::rc::autoreleasepool;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

use candle_transformers::models::whisper::{self as m, audio, Config};
use rubato::{
Expand Down Expand Up @@ -826,19 +825,19 @@ pub async fn create_whisper_channel(
output_path: &PathBuf,
vad_sensitivity: VadSensitivity,
) -> Result<(
UnboundedSender<AudioInput>,
UnboundedReceiver<TranscriptionResult>,
crossbeam::channel::Sender<AudioInput>,
crossbeam::channel::Receiver<TranscriptionResult>,
Arc<AtomicBool>, // Shutdown flag
)> {
let whisper_model = WhisperModel::new(audio_transcription_engine.clone())?;
let (input_sender, mut input_receiver): (
UnboundedSender<AudioInput>,
UnboundedReceiver<AudioInput>,
) = unbounded_channel();
let (input_sender, input_receiver): (
crossbeam::channel::Sender<AudioInput>,
crossbeam::channel::Receiver<AudioInput>,
) = crossbeam::channel::bounded(100);
let (output_sender, output_receiver): (
UnboundedSender<TranscriptionResult>,
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();
crossbeam::channel::Sender<TranscriptionResult>,
crossbeam::channel::Receiver<TranscriptionResult>,
) = crossbeam::channel::bounded(100);
let mut vad_engine: Box<dyn VadEngine + Send> = match vad_engine {
VadEngineEnum::WebRtc => Box::new(WebRtcVad::new()),
VadEngineEnum::Silero => Box::new(SileroVad::new().await?),
Expand All @@ -857,18 +856,46 @@ pub async fn create_whisper_channel(
}
debug!("Waiting for input from input_receiver");

tokio::select! {
Some(input) = input_receiver.recv() => {
debug!("Received input from input_receiver");
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();

let transcription_result = if cfg!(target_os = "macos") {
#[cfg(target_os = "macos")]
{
autoreleasepool(|| {
crossbeam::select! {
recv(input_receiver) -> input_result => {
match input_result {
Ok(input) => {
debug!("Received input from input_receiver");
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();

let transcription_result = if cfg!(target_os = "macos") {
#[cfg(target_os = "macos")]
{
autoreleasepool(|| {
match stt_sync(&input, &whisper_model, audio_transcription_engine.clone(), vad_engine.clone(), deepgram_api_key.clone(), &output_path) {
Ok((transcription, path)) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
path,
timestamp,
error: None,
},
Err(e) => {
error!("STT error for input {}: {:?}", input.device, e);
TranscriptionResult {
input: input.clone(),
transcription: None,
path: "".to_string(),
timestamp,
error: Some(e.to_string()),
}
},
}
})
}
#[cfg(not(target_os = "macos"))]
{
unreachable!("This code should not be reached on non-macOS platforms")
}
} else {
match stt_sync(&input, &whisper_model, audio_transcription_engine.clone(), vad_engine.clone(), deepgram_api_key.clone(), &output_path) {
Ok((transcription, path)) => TranscriptionResult {
input: input.clone(),
Expand All @@ -888,39 +915,20 @@ pub async fn create_whisper_channel(
}
},
}
})
};

if output_sender.send(transcription_result).is_err() {
break;
}
},
Err(e) => {
error!("Error receiving input: {:?}", e);
// Depending on the error type, you might want to break the loop or continue
// For now, we'll continue to the next iteration
continue;
}
#[cfg(not(target_os = "macos"))]
{
unreachable!("This code should not be reached on non-macOS platforms")
}
} else {
match stt_sync(&input, &whisper_model, audio_transcription_engine.clone(), vad_engine.clone(), deepgram_api_key.clone(), &output_path) {
Ok((transcription, path)) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
path,
timestamp,
error: None,
},
Err(e) => {
error!("STT error for input {}: {:?}", input.device, e);
TranscriptionResult {
input: input.clone(),
transcription: None,
path: "".to_string(),
timestamp,
error: Some(e.to_string()),
}
},
}
};

if output_sender.send(transcription_result).is_err() {
break;
}
}
else => break,
},
}
}
// Cleanup code here (if needed)
Expand Down
17 changes: 9 additions & 8 deletions screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;

pub async fn start_continuous_recording(
Expand Down Expand Up @@ -47,12 +46,14 @@ pub async fn start_continuous_recording(
) -> Result<()> {
let (whisper_sender, whisper_receiver, whisper_shutdown_flag) = if audio_disabled {
// Create a dummy channel if no audio devices are available, e.g. audio disabled
let (input_sender, _): (UnboundedSender<AudioInput>, UnboundedReceiver<AudioInput>) =
unbounded_channel();
let (input_sender, _): (
crossbeam::channel::Sender<AudioInput>,
crossbeam::channel::Receiver<AudioInput>,
) = crossbeam::channel::bounded(100);
let (_, output_receiver): (
UnboundedSender<TranscriptionResult>,
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();
crossbeam::channel::Sender<TranscriptionResult>,
crossbeam::channel::Receiver<TranscriptionResult>,
) = crossbeam::channel::bounded(100);
(
input_sender,
output_receiver,
Expand Down Expand Up @@ -252,8 +253,8 @@ async fn record_video(
async fn record_audio(
db: Arc<DatabaseManager>,
chunk_duration: Duration,
whisper_sender: UnboundedSender<AudioInput>,
mut whisper_receiver: UnboundedReceiver<TranscriptionResult>,
whisper_sender: crossbeam::channel::Sender<AudioInput>,
whisper_receiver: crossbeam::channel::Receiver<TranscriptionResult>,
audio_devices_control: Arc<SegQueue<(AudioDevice, DeviceControl)>>,
friend_wearable_uid: Option<String>,
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
Expand Down

0 comments on commit 7f54487

Please sign in to comment.