Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: DeepFilter #1017

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ crossbeam = "0.8.4"
image = "0.25"
reqwest = { version = "0.11", features = ["blocking", "multipart", "json"] }
criterion = { version = "0.5.1", features = ["async_tokio"] }
vcpkg = "0.2"
cc = "1.0"

once_cell = "1.20.2"

Expand Down
8 changes: 7 additions & 1 deletion screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,19 @@ crossbeam = { workspace = true }
# Directories
dirs = "5.0.1"

deep_filter = { git = "https://github.com/EzraEllette/DeepFilterNet.git", features = [
"transforms",
"tract",
] }
lazy_static = { version = "1.4.0" }
realfft = "3.4.0"
regex = "1.11.0"
ndarray = "0.16"
tract-core = { version = "^0.21.4" }
ort = "=2.0.0-rc.6"
knf-rs = { git = "https://github.com/Neptune650/knf-rs.git" }
knf-rs = { git = "https://github.com/Neptune650/knf-rs.git", branch = "main" }
ort-sys = "=2.0.0-rc.8"
parking_lot = { version = "0.12.3", features = ["send_guard"] }

[target.'cfg(target_os = "windows")'.dependencies]
ort = { version = "=2.0.0-rc.6", features = [
Expand Down
Binary file not shown.
Binary file not shown.
103 changes: 0 additions & 103 deletions screenpipe-audio/src/audio_processing.rs

This file was deleted.

17 changes: 17 additions & 0 deletions screenpipe-audio/src/audio_processing/audio_to_mono.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pub fn audio_to_mono(audio: &[f32], channels: u16) -> Vec<f32> {
let mut mono_samples = Vec::with_capacity(audio.len() / channels as usize);

// Iterate over the audio slice in chunks, each containing `channels` samples
for chunk in audio.chunks(channels as usize) {
// Sum the samples from all channels in the current chunk
let sum: f32 = chunk.iter().sum();

// Calculate the averagechannelsono sample
let mono_sample = sum / channels as f32;

// Store the computed mono sample
mono_samples.push(mono_sample);
}

mono_samples
}
8 changes: 8 additions & 0 deletions screenpipe-audio/src/audio_processing/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
mod audio_to_mono;
mod noise_reduction;
mod normalize_v2;

pub use audio_to_mono::audio_to_mono;
pub use noise_reduction::NoiseFilter;
pub use noise_reduction::NoiseReductionModel;
pub use normalize_v2::normalize_v2;
228 changes: 228 additions & 0 deletions screenpipe-audio/src/audio_processing/noise_reduction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
use anyhow::Result;
use df::tract::{DfParams, DfTract, RuntimeParams};
use df::transforms::resample;
use dirs;
use lazy_static::lazy_static;
use log::{debug, info};
use std::path::PathBuf;
use std::sync::Once;
use tokio::sync::Mutex;
use tract_core::ndarray::{Array2, ArrayD, Axis};

lazy_static! {
static ref V3_MODEL_PATH: Mutex<Option<PathBuf>> = Mutex::new(None);
static ref V3LL_MODEL_PATH: Mutex<Option<PathBuf>> = Mutex::new(None);
}

static DOWNLOAD_V3_ONCE: Once = Once::new();
static DOWNLOAD_V3LL_ONCE: Once = Once::new();

#[derive(Clone, Copy)]
pub enum NoiseReductionModel {
V3,
V3LL,
}

impl std::fmt::Display for NoiseReductionModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NoiseReductionModel::V3 => write!(f, "DeepFilterNet3_onnx.tar.gz"),
NoiseReductionModel::V3LL => write!(f, "DeepFilterNet3_ll_onnx.tar.gz"),
}
}
}

impl NoiseReductionModel {
pub fn filename(&self) -> &'static str {
match self {
NoiseReductionModel::V3 => "DeepFilterNet3_onnx.tar.gz",
NoiseReductionModel::V3LL => "DeepFilterNet3_ll_onnx.tar.gz",
}
}
}

/// NoiseFilter is used to process the audio stream and reduce the noise.
/// Audio processed by this filter will be resampled to 16KHz.
pub struct NoiseFilter {
delay: usize,

Check warning on line 47 in screenpipe-audio/src/audio_processing/noise_reduction.rs

View workflow job for this annotation

GitHub Actions / test-ubuntu

field `delay` is never read

Check warning on line 47 in screenpipe-audio/src/audio_processing/noise_reduction.rs

View workflow job for this annotation

GitHub Actions / test-macos

field `delay` is never read

Check warning on line 47 in screenpipe-audio/src/audio_processing/noise_reduction.rs

View workflow job for this annotation

GitHub Actions / test-linux

field `delay` is never read

Check warning on line 47 in screenpipe-audio/src/audio_processing/noise_reduction.rs

View workflow job for this annotation

GitHub Actions / test-windows

field `delay` is never read

Check warning on line 47 in screenpipe-audio/src/audio_processing/noise_reduction.rs

View workflow job for this annotation

GitHub Actions / test-windows

field `delay` is never read
model: DfTract,
stream_sample_rate: usize,
model_sample_rate: usize,
}

unsafe impl Send for NoiseFilter {}
unsafe impl Sync for NoiseFilter {}

impl NoiseFilter {
/// Initialize the noise filter with the given model and sample rate.
///
/// This function will download the model if it is not already present in the cache directory.
///
/// # Arguments
/// * `model_path` - The path to the noise reduction model
///
/// # Returns
/// The initialized noise filter
pub async fn new(model_path: PathBuf, sample_rate: u32) -> Result<Self> {
let mut r_params = RuntimeParams::default();
r_params = r_params
.with_atten_lim(100.0)
.with_thresholds(-15.0, 35.0, 35.0);
// if args.post_filter {
r_params = r_params.with_post_filter(0.02);
// }
if let Ok(red) = 1.try_into() {
r_params = r_params.with_mask_reduce(red);
} else {
log::warn!("Input not valid for `reduce_mask`.")
}
let df_params = match DfParams::new(model_path.clone()) {
Ok(p) => p,
Err(e) => {
log::error!("Error opening model {}: {}", model_path.display(), e);
return Err(e);
}
};

let mut model: DfTract = DfTract::new(df_params.clone(), &r_params)?;
model.ch = 1;
let sr = model.sr;
let mut delay = model.fft_size - model.hop_size; // STFT delay
delay += model.lookahead * model.hop_size; // Add model latency due to lookahead
Ok(Self {
model,
delay,
stream_sample_rate: sample_rate as usize,
model_sample_rate: sr,
})
}

pub fn process(&mut self, input: &[f32]) -> Result<Vec<f32>> {
// if self.model.n_ch != reader.channels {
// self.model.n_ch = reader.channels;
// model = DfTract::new(df_params.clone(), &r_params)?;
// sr = model.sr;
// }

let mut noisy = Array2::from_shape_vec((1, input.len()), input.to_vec())?;

if self.model_sample_rate != self.stream_sample_rate {
noisy = resample(
noisy.view(),
self.stream_sample_rate,
self.model_sample_rate,
None,
)
.expect("Error during resample()");
}
let noisy = noisy.as_standard_layout();
let mut enh: Array2<f32> = ArrayD::default(noisy.shape()).into_dimensionality()?;

for (ns_f, enh_f) in noisy
.view()
.axis_chunks_iter(Axis(1), self.model.hop_size)
.zip(
enh.view_mut()
.axis_chunks_iter_mut(Axis(1), self.model.hop_size),
)
{
if ns_f.len_of(Axis(1)) < self.model.hop_size {
break;
}
self.model.process(ns_f, enh_f)?;
}

// if self.compensate_delay {
// enh.slice_axis_inplace(Axis(1), tract_core::ndarray::Slice::from(self.delay..));
// }

if self.model_sample_rate != self.stream_sample_rate {
enh = resample(
enh.view(),
self.model_sample_rate,
self.stream_sample_rate,
None,
)
.expect("Error during resample()");
}

Ok(enh.view().to_slice().unwrap().to_vec())
}

pub async fn get_or_download_model(model_type: NoiseReductionModel) -> Result<PathBuf> {
let (model_path, model_caller): (&Mutex<Option<PathBuf>>, &Once) = match model_type {
NoiseReductionModel::V3 => (&*V3_MODEL_PATH, &DOWNLOAD_V3_ONCE),
NoiseReductionModel::V3LL => (&*V3LL_MODEL_PATH, &DOWNLOAD_V3LL_ONCE),
};

let mut model_path = model_path.lock().await;
if let Some(path) = model_path.as_ref() {
debug!("using cached {} model: {:?}", model_type, path);
return Ok(path.clone());
}

let cache_dir = NoiseFilter::get_cache_dir()?;
let path = cache_dir.join(format!("{}.onnx", model_type));

if path.exists() {
debug!("found existing {} model at: {:?}", model_type, path);
*model_path = Some(path.clone());
return Ok(path);
}

info!("initiating {} model download...", model_type);
model_caller.call_once(|| {
tokio::spawn(async move {
if let Err(e) = NoiseFilter::download_model(model_type).await {
debug!("error downloading {} model: {}", model_type, e);
}
});
});

while !path.exists() {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}

*model_path = Some(path.clone());
Ok(path)
}

async fn download_model(model_type: NoiseReductionModel) -> Result<()> {
let (url, filename) = match model_type {
NoiseReductionModel::V3 => (
"https://github.com/mediar-ai/screenpipe/raw/refs/heads/main/screenpipe-audio/models/deep-filter-net/DeepFilterNet3_onnx.tar.gz",
"DeepFilterNet3_onnx.tar.gz",
),
NoiseReductionModel::V3LL => (
"https://github.com/mediar-ai/screenpipe/raw/refs/heads/main/screenpipe-audio/models/deep-filter-net/DeepFilterNet3_ll_onnx.tar.gz",
"DeepFilterNet3_ll_onnx.tar.gz",
),
};

info!("downloading {} model from {}", filename, url);
let response = reqwest::get(url).await?;
let model_data = response.bytes().await?;

let cache_dir = NoiseFilter::get_cache_dir()?;
tokio::fs::create_dir_all(&cache_dir).await?;
let path = cache_dir.join(filename);

info!(
"saving {} model ({} bytes) to {:?}",
filename,
model_data.len(),
path
);
let mut file = tokio::fs::File::create(&path).await?;
tokio::io::AsyncWriteExt::write_all(&mut file, &model_data).await?;
info!("{} model successfully downloaded and saved", filename);

Ok(())
}

fn get_cache_dir() -> Result<PathBuf> {
let proj_dirs =
dirs::cache_dir().ok_or_else(|| anyhow::anyhow!("failed to get cache dir"))?;
Ok(proj_dirs.join("screenpipe").join("models"))
}
}
Loading
Loading