From eb5b8f895502ba19e80f0cd17bee83a4a00596e2 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Wed, 13 Apr 2022 00:48:42 -0400 Subject: [PATCH 01/11] Enable photon noise arg in stable Add pre-processing denoising Align denoising arrays Autovectorization pass --- Cargo.toml | 10 +- clippy.toml | 1 + src/api/config/encoder.rs | 3 + src/api/internal.rs | 34 +- src/api/test.rs | 2 + src/bin/common.rs | 26 +- src/bin/rav1e-ch.rs | 1 - src/bin/rav1e.rs | 1 - src/denoise.rs | 587 ++++++++++++++++++++++++++++++++++ src/fuzzing.rs | 1 + src/lib.rs | 1 + src/test_encode_decode/mod.rs | 3 +- src/util/align.rs | 14 + 13 files changed, 674 insertions(+), 10 deletions(-) create mode 100644 src/denoise.rs diff --git a/Cargo.toml b/Cargo.toml index 7de4ccce74..cddc892753 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,6 @@ binaries = [ "fern", "console", "av-metrics", - "nom", ] default = ["binaries", "asm", "threading", "signal_support"] asm = ["nasm-rs", "cc", "regex"] @@ -100,11 +99,18 @@ wasm-bindgen = { version = "0.2.63", optional = true } rust_hawktracer = "0.7.0" arrayref = "0.3.6" const_fn_assert = "0.1.2" -nom = { version = "7.0.0", optional = true } +# `unreachable!` macro which panics in debug mode +# and optimizes away in release mode new_debug_unreachable = "1.0.4" once_cell = "1.13.0" av1-grain = { version = "0.2.0", features = ["serialize"] } serde-big-array = { version = "0.4.1", optional = true } +# Used for parsing film grain table files +nom = "7.0.0" +# Used as a data holder during denoising +ndarray = "0.15.4" +# Used for running FFTs during denoising +ndrustfft = "0.3.0" [dependencies.image] version = "0.24.3" diff --git a/clippy.toml b/clippy.toml index f26dd286a7..2b357613e4 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,4 +1,5 @@ too-many-arguments-threshold = 16 cognitive-complexity-threshold = 40 trivial-copy-size-limit = 16 # 128-bits = 2 64-bit registers +doc-valid-idents = ["DFTTest"] msrv = "1.59" diff --git a/src/api/config/encoder.rs b/src/api/config/encoder.rs index 7f84d5a081..c36bbda0f7 100644 --- a/src/api/config/encoder.rs +++ b/src/api/config/encoder.rs @@ -85,6 +85,8 @@ pub struct EncoderConfig { pub tune: Tune, /// Parameters for grain synthesis. pub film_grain_params: Option>, + /// Strength of denoising, 0 = disabled + pub denoise_strength: u8, /// Number of tiles horizontally. Must be a power of two. /// /// Overridden by [`tiles`], if present. @@ -159,6 +161,7 @@ impl EncoderConfig { bitrate: 0, tune: Tune::default(), film_grain_params: None, + denoise_strength: 0, tile_cols: 0, tile_rows: 0, tiles: 0, diff --git a/src/api/internal.rs b/src/api/internal.rs index 169a9c96ed..e9d0743c5b 100644 --- a/src/api/internal.rs +++ b/src/api/internal.rs @@ -15,6 +15,7 @@ use crate::api::{ }; use crate::color::ChromaSampling::Cs400; use crate::cpu_features::CpuFeatureLevel; +use crate::denoise::{DftDenoiser, TB_MIDPOINT}; use crate::dist::get_satd; use crate::encoder::*; use crate::frame::*; @@ -220,7 +221,7 @@ impl FrameData { } } -type FrameQueue = BTreeMap>>>; +pub(crate) type FrameQueue = BTreeMap>>>; type FrameDataQueue = BTreeMap>>; // the fields pub(super) are accessed only by the tests @@ -248,6 +249,7 @@ pub(crate) struct ContextInner { /// Maps `output_frameno` to `gop_input_frameno_start`. 
pub(crate) gop_input_frameno_start: BTreeMap, keyframe_detector: SceneChangeDetector, + denoiser: Option>, pub(crate) config: Arc, seq: Arc, pub(crate) rc_state: RCState, @@ -295,6 +297,17 @@ impl ContextInner { lookahead_distance, seq.clone(), ), + denoiser: if enc.denoise_strength > 0 { + Some(DftDenoiser::new( + enc.denoise_strength as f32 / 10.0, + enc.width, + enc.height, + enc.bit_depth as u8, + enc.chroma_sampling, + )) + } else { + None + }, config: Arc::new(enc.clone()), seq, rc_state: RCState::new( @@ -359,6 +372,25 @@ impl ContextInner { self.t35_q.insert(input_frameno, params.t35_metadata); } + // If denoising is enabled, run it now because we want the entire + // encoding process, including lookahead, to see the denoised frame. + if let Some(ref mut denoiser) = self.denoiser { + loop { + let denoiser_frame = denoiser.cur_frameno; + if (!is_flushing + && input_frameno >= denoiser_frame + TB_MIDPOINT as u64) + || (is_flushing && Some(denoiser_frame) < self.limit) + { + self.frame_q.insert( + denoiser_frame, + Some(Arc::new(denoiser.filter_frame(&self.frame_q).unwrap())), + ); + } else { + break; + } + } + } + if !self.needs_more_frame_q_lookahead(self.next_lookahead_frame) { let lookahead_frames = self .frame_q diff --git a/src/api/test.rs b/src/api/test.rs index 072562631f..3d167fdfc5 100644 --- a/src/api/test.rs +++ b/src/api/test.rs @@ -2131,6 +2131,7 @@ fn log_q_exp_overflow() { tile_cols: 0, tile_rows: 0, tiles: 0, + denoise_strength: 0, speed_settings: SpeedSettings { multiref: false, fast_deblock: true, @@ -2207,6 +2208,7 @@ fn guess_frame_subtypes_assert() { tile_cols: 0, tile_rows: 0, tiles: 0, + denoise_strength: 0, speed_settings: SpeedSettings { multiref: false, fast_deblock: true, diff --git a/src/bin/common.rs b/src/bin/common.rs index 18cb05d388..7b21a2ef6b 100644 --- a/src/bin/common.rs +++ b/src/bin/common.rs @@ -176,7 +176,6 @@ pub struct CliOptions { pub still_picture: bool, /// Uses grain synthesis to add photon noise to the resulting encode. /// Takes a strength value 0-64. - #[cfg(feature = "unstable")] #[clap( long, conflicts_with = "film-grain-table", @@ -185,6 +184,17 @@ pub struct CliOptions { help_heading = "ENCODE SETTINGS" )] pub photon_noise: u8, + /// Enable spatio-temporal denoising, intended to be used with grain synthesis. + /// Takes a strength value 0-50. + /// + /// Default strength is 1/2 of photon noise strength, + /// or 4 if a photon noise table is specified. + #[clap( + long, + value_parser = clap::value_parser!(u8).range(0..=50), + help_heading = "ENCODE SETTINGS" + )] + pub denoise: Option, /// Uses a film grain table file to apply grain synthesis to the encode. /// Uses the same table file format as aomenc and svt-av1. #[clap( @@ -334,7 +344,6 @@ pub struct ParsedCliOptions { pub save_config: Option, #[cfg(feature = "unstable")] pub slots: usize, - #[cfg(feature = "unstable")] pub generate_grain_strength: u8, } @@ -482,7 +491,6 @@ pub fn parse_cli() -> Result { save_config: save_config_path, #[cfg(feature = "unstable")] slots, - #[cfg(feature = "unstable")] generate_grain_strength: matches.photon_noise, }) } @@ -676,7 +684,19 @@ fn parse_config(matches: &CliOptions) -> Result { .expect("Failed to parse film grain table"); if !table.is_empty() { cfg.film_grain_params = Some(table); + cfg.denoise_strength = 4; + } + } else if matches.photon_noise > 0 { + cfg.denoise_strength = matches.photon_noise / 2; + // We have to know the video resolution before we can generate a table, + // so we must handle that elsewhere. 
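+    // The table itself is generated in the binaries via
+    // `generate_photon_noise_params` once the input resolution is
+    // known (see the `rav1e.rs` and `rav1e-ch.rs` hunks below).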
+ } + // A user set denoise strength overrides the defaults above + if let Some(denoise_str) = matches.denoise { + if denoise_str > 50 { + panic!("Denoising strength must be between 0-50"); } + cfg.denoise_strength = denoise_str; } if let Some(frame_rate) = matches.frame_rate { diff --git a/src/bin/rav1e-ch.rs b/src/bin/rav1e-ch.rs index e0221fbdd9..71d559995b 100644 --- a/src/bin/rav1e-ch.rs +++ b/src/bin/rav1e-ch.rs @@ -473,7 +473,6 @@ fn run() -> Result<(), error::CliError> { cli.enc.time_base = video_info.time_base; } - #[cfg(feature = "unstable")] if cli.generate_grain_strength > 0 && cli.enc.film_grain_params.is_none() { cli.enc.film_grain_params = Some(vec![generate_photon_noise_params( 0, diff --git a/src/bin/rav1e.rs b/src/bin/rav1e.rs index 8d672736d8..2390ffb0bb 100644 --- a/src/bin/rav1e.rs +++ b/src/bin/rav1e.rs @@ -456,7 +456,6 @@ fn run() -> Result<(), error::CliError> { cli.enc.time_base = video_info.time_base; } - #[cfg(feature = "unstable")] if cli.generate_grain_strength > 0 && cli.enc.film_grain_params.is_none() { cli.enc.film_grain_params = Some(vec![generate_photon_noise_params( 0, diff --git a/src/denoise.rs b/src/denoise.rs new file mode 100644 index 0000000000..93cfb499ae --- /dev/null +++ b/src/denoise.rs @@ -0,0 +1,587 @@ +use crate::api::FrameQueue; +use crate::util::Aligned; +use crate::EncoderStatus; +use arrayvec::ArrayVec; +use ndarray::{Array3, ArrayView3, ArrayViewMut3}; +use ndrustfft::{ + ndfft, ndfft_r2c, ndifft, ndifft_r2c, Complex, FftHandler, R2cFftHandler, +}; +use std::collections::{BTreeMap, VecDeque}; +use std::f64::consts::PI; +use std::iter::once; +use std::mem::size_of; +use std::ptr::copy_nonoverlapping; +use std::sync::Arc; +use v_frame::frame::Frame; +use v_frame::math::clamp; +use v_frame::pixel::{CastFromPrimitive, ChromaSampling, Pixel}; +use v_frame::plane::Plane; + +const SB_SIZE: usize = 16; +const SO_SIZE: usize = 12; +const TB_SIZE: usize = 3; +pub(crate) const TB_MIDPOINT: usize = TB_SIZE / 2; +const BLOCK_AREA: usize = SB_SIZE * SB_SIZE; +const BLOCK_VOLUME: usize = BLOCK_AREA * TB_SIZE; +const COMPLEX_COUNT: usize = (SB_SIZE / 2 + 1) * SB_SIZE * TB_SIZE; +const CCNT2: usize = COMPLEX_COUNT * 2; +const INC: usize = SB_SIZE - SO_SIZE; + +/// This denoiser is based on the DFTTest plugin from Vapoursynth. +/// This type of denoising was chosen because it provides +/// high quality while not being too slow. +pub(crate) struct DftDenoiser +where + T: Pixel, +{ + chroma_sampling: ChromaSampling, + dest_scale: f32, + src_scale: f32, + peak: T, + + // These indices refer to planes of the input + pad_dimensions: ArrayVec<(usize, usize), 3>, + effective_heights: ArrayVec, + + hw: Aligned<[f32; BLOCK_VOLUME]>, + dftgc: Aligned<[Complex; COMPLEX_COUNT]>, + fft: (R2cFftHandler, FftHandler, FftHandler), + sigmas: Aligned<[f32; CCNT2]>, + + // This stores a copy of the unfiltered previous frame, + // since in `frame_q` it will be filtered already. + // We only have one frame, but it's left as a Vec so that + // TB_SIZE could potentially be tweaked without any + // code changes. + frame_buffer: VecDeque>>, + pub(crate) cur_frameno: u64, +} + +impl DftDenoiser +where + T: Pixel, +{ + // This should only need to run once per video. 
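+  //
+  // A minimal construction sketch (hypothetical caller; the argument
+  // values below are illustrative, not defaults):
+  //
+  //   let mut denoiser: DftDenoiser<u8> =
+  //     DftDenoiser::new(1.2, 1920, 1080, 8, ChromaSampling::Cs420);
+  //
+  // `filter_frame` is then polled once per input frame; it returns
+  // `EncoderStatus::NeedMoreData` until enough neighboring frames are
+  // present in the frame queue.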
+ pub fn new( + sigma: f32, width: usize, height: usize, bit_depth: u8, + chroma_sampling: ChromaSampling, + ) -> Self { + if size_of::() == 1 { + assert!(bit_depth <= 8); + } else { + assert!(bit_depth > 8); + } + + let dest_scale = (1 << (bit_depth - 8)) as f32; + let src_scale = 1.0 / dest_scale; + let peak = T::cast_from((1u16 << bit_depth) - 1); + + let mut pad_dimensions = ArrayVec::<_, 3>::new(); + let mut effective_heights = ArrayVec::<_, 3>::new(); + for plane in 0..3 { + let ae = (SB_SIZE - SO_SIZE).max(SO_SIZE) * 2; + let (width, height) = if plane == 0 { + (width, height) + } else { + chroma_sampling.get_chroma_dimensions(width, height) + }; + let pad_w = width + extra(width, SB_SIZE) + ae; + let pad_h = height + extra(height, SB_SIZE) + ae; + let e_h = + ((pad_h - SO_SIZE) / (SB_SIZE - SO_SIZE)) * (SB_SIZE - SO_SIZE); + pad_dimensions.push((pad_w, pad_h)); + effective_heights.push(e_h); + } + + let hw = Aligned::new(Self::create_window()); + let mut dftgr = Aligned::new([0f32; BLOCK_VOLUME]); + + let fft = ( + R2cFftHandler::new(SB_SIZE), + FftHandler::new(SB_SIZE), + FftHandler::new(TB_SIZE), + ); + + let mut wscale = 0.0f32; + for k in 0..BLOCK_VOLUME { + dftgr[k] = 255.0 * hw[k]; + wscale += hw[k].powi(2); + } + let wscale = 1.0 / wscale; + + let mut sigmas = Aligned::new([0f32; CCNT2]); + sigmas.fill(sigma / wscale); + + let mut denoiser = DftDenoiser { + chroma_sampling, + dest_scale, + src_scale, + peak, + pad_dimensions, + effective_heights, + hw, + fft, + sigmas, + dftgc: Aligned::new([Complex::default(); COMPLEX_COUNT]), + frame_buffer: VecDeque::with_capacity(TB_MIDPOINT), + cur_frameno: 0, + }; + + let mut dftgc = Aligned::new([Complex::default(); COMPLEX_COUNT]); + denoiser.real_to_complex_3d(&dftgr, &mut dftgc); + denoiser.dftgc = dftgc; + + denoiser + } + + pub fn filter_frame( + &mut self, frame_q: &FrameQueue, + ) -> Result, EncoderStatus> { + if self.frame_buffer.len() < TB_MIDPOINT.min(self.cur_frameno as usize) { + // We need to have the previous unfiltered frame + // in the buffer for temporal filtering. + return Err(EncoderStatus::NeedMoreData); + } + let future_frames = frame_q + .range((self.cur_frameno + 1)..) + .take(TB_MIDPOINT) + .map(|(_, f)| f) + .collect::>(); + if future_frames.len() != TB_MIDPOINT + && !future_frames.iter().any(|f| f.is_none()) + { + // We also need to have the next unfiltered frame, + // unless we are at the end of the video. 
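+      // (A `None` entry in the frame queue marks the end of the
+      // video, which is why a present `None` lets filtering proceed.)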
+ return Err(EncoderStatus::NeedMoreData); + } + + let orig_frame = frame_q.get(&self.cur_frameno).unwrap().as_ref().unwrap(); + let frames = self + .frame_buffer + .iter() + .cloned() + .enumerate() + .chain(once(((TB_MIDPOINT), Arc::clone(orig_frame)))) + .chain( + future_frames + .into_iter() + .flatten() + .cloned() + .enumerate() + .map(|(i, f)| (i + 1 + TB_MIDPOINT, f)), + ) + .collect::>(); + + let mut dest = (**orig_frame).clone(); + let mut pad = ArrayVec::<_, TB_SIZE>::new(); + for i in 0..TB_SIZE { + let dec = self.chroma_sampling.get_decimation().unwrap_or((0, 0)); + let mut pad_frame = [ + Plane::new( + self.pad_dimensions[0].0, + self.pad_dimensions[0].1, + 0, + 0, + 0, + 0, + ), + Plane::new( + self.pad_dimensions[1].0, + self.pad_dimensions[1].1, + dec.0, + dec.1, + 0, + 0, + ), + Plane::new( + self.pad_dimensions[2].0, + self.pad_dimensions[2].1, + dec.0, + dec.1, + 0, + 0, + ), + ]; + + let frame = frames.get(&i).unwrap_or(&frames[&TB_MIDPOINT]); + self.copy_pad(frame, &mut pad_frame); + pad.push(pad_frame); + } + self.do_filtering(&pad, &mut dest); + + if self.frame_buffer.len() == TB_MIDPOINT { + self.frame_buffer.pop_front(); + } + self.frame_buffer.push_back(Arc::clone(orig_frame)); + self.cur_frameno += 1; + + Ok(dest) + } + + fn do_filtering(&mut self, src: &[[Plane; 3]], dest: &mut Frame) { + let mut dftr = [0f32; BLOCK_VOLUME]; + let mut dftc = [Complex::::default(); COMPLEX_COUNT]; + let mut means = [Complex::::default(); COMPLEX_COUNT]; + + for p in 0..3 { + let (pad_width, pad_height) = self.pad_dimensions[p]; + let mut ebuff = vec![0f32; pad_width * pad_height]; + let effective_height = self.effective_heights[p]; + let src_stride = src[0][p].cfg.stride; + let ebuff_stride = pad_width; + + let mut src_planes = src + .iter() + .map(|f| f[p].data_origin()) + .collect::>(); + + // SAFETY: We know the size of the planes we're working on, + // so we can safely ensure we are not out of bounds. + // There are a fair number of unsafe function calls here + // which are unsafe for optimization purposes. + // All are safe as long as we do not pass out-of-bounds parameters. 
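+      //
+      // Note on the traversal below: blocks are visited at steps of
+      // INC = SB_SIZE - SO_SIZE pixels, so neighboring 16x16 blocks
+      // overlap by SO_SIZE pixels; `proc1` accumulates each block's
+      // windowed result into `ebuff` (overlap-add).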
+ unsafe { + for y in (0..effective_height).step_by(INC) { + for x in (0..=(pad_width - SB_SIZE)).step_by(INC) { + for z in 0..TB_SIZE { + self.proc0( + &src_planes[z][x..], + &self.hw[(BLOCK_AREA * z)..], + &mut dftr[(BLOCK_AREA * z)..], + src_stride, + SB_SIZE, + self.src_scale, + ); + } + + self.real_to_complex_3d(&dftr, &mut dftc); + self.remove_mean(&mut dftc, &self.dftgc, &mut means); + + self.filter_coeffs(&mut dftc); + + self.add_mean(&mut dftc, &means); + self.complex_to_real_3d(&dftc, &mut dftr); + + self.proc1( + &dftr[(TB_MIDPOINT * BLOCK_AREA)..], + &self.hw[(TB_MIDPOINT * BLOCK_AREA)..], + &mut ebuff[(y * ebuff_stride + x)..], + SB_SIZE, + ebuff_stride, + ); + } + + for q in 0..TB_SIZE { + src_planes[q] = &src_planes[q][(INC * src_stride)..]; + } + } + } + + let dest_width = dest.planes[p].cfg.width; + let dest_height = dest.planes[p].cfg.height; + let dest_stride = dest.planes[p].cfg.stride; + let dest_plane = dest.planes[p].data_origin_mut(); + let ebp_offset = ebuff_stride * ((pad_height - dest_height) / 2) + + (pad_width - dest_width) / 2; + let ebp = &ebuff[ebp_offset..]; + + self.cast( + ebp, + dest_plane, + dest_width, + dest_height, + dest_stride, + ebuff_stride, + ); + } + } + + fn create_window() -> [f32; BLOCK_VOLUME] { + let mut hw = [0f32; BLOCK_VOLUME]; + let mut tw = [0f64; TB_SIZE]; + let mut sw = [0f64; SB_SIZE]; + + tw.fill_with(Self::temporal_window); + sw.iter_mut().enumerate().for_each(|(j, sw)| { + *sw = Self::spatial_window(j as f64 + 0.5); + }); + Self::normalize_for_overlap_add(&mut sw); + + let nscale = 1.0 / (BLOCK_VOLUME as f64).sqrt(); + for j in 0..TB_SIZE { + for k in 0..SB_SIZE { + for q in 0..SB_SIZE { + hw[(j * SB_SIZE + k) * SB_SIZE + q] = + (tw[j] * sw[k] * sw[q] * nscale) as f32; + } + } + } + + hw + } + + #[inline(always)] + // Hanning windowing + fn spatial_window(n: f64) -> f64 { + 0.5 - 0.5 * (2.0 * PI * n / SB_SIZE as f64).cos() + } + + #[inline(always)] + // Simple rectangular windowing + fn temporal_window() -> f64 { + 1.0 + } + + // Accounts for spatial block overlap + fn normalize_for_overlap_add(hw: &mut [f64]) { + let inc = SB_SIZE - SO_SIZE; + + let mut nw = [0f64; SB_SIZE]; + let hw = &mut hw[..SB_SIZE]; + + for q in 0..SB_SIZE { + for h in (0..=q).rev().step_by(inc) { + nw[q] += hw[h].powi(2); + } + for h in ((q + inc)..SB_SIZE).step_by(inc) { + nw[q] += hw[h].powi(2); + } + } + + for q in 0..SB_SIZE { + hw[q] /= nw[q].sqrt(); + } + } + + #[inline] + unsafe fn proc0( + &self, s0: &[T], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, + src_scale: f32, + ) { + let s0 = s0.as_ptr(); + let s1 = s1.as_ptr(); + let dest = dest.as_mut_ptr(); + + for u in 0..p1 { + for v in 0..p1 { + let s0 = s0.add(u * p0 + v); + let s1 = s1.add(u * p1 + v); + let dest = dest.add(u * p1 + v); + dest.write(u16::cast_from(s0.read()) as f32 * src_scale * s1.read()) + } + } + } + + #[inline] + unsafe fn proc1( + &self, s0: &[f32], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, + ) { + let s0 = s0.as_ptr(); + let s1 = s1.as_ptr(); + let dest = dest.as_mut_ptr(); + + for u in 0..p0 { + for v in 0..p0 { + let s0 = s0.add(u * p0 + v); + let s1 = s1.add(u * p0 + v); + let dest = dest.add(u * p1 + v); + dest.write(s0.read().mul_add(s1.read(), dest.read())); + } + } + } + + #[inline] + fn remove_mean( + &self, dftc: &mut [Complex; COMPLEX_COUNT], + dftgc: &[Complex; COMPLEX_COUNT], + means: &mut [Complex; COMPLEX_COUNT], + ) { + let gf = dftc[0].re / dftgc[0].re; + + for h in 0..COMPLEX_COUNT { + means[h].re = gf * dftgc[h].re; + means[h].im = gf * 
dftgc[h].im; + dftc[h].re -= means[h].re; + dftc[h].im -= means[h].im; + } + } + + #[inline] + fn add_mean( + &self, dftc: &mut [Complex; COMPLEX_COUNT], + means: &[Complex; COMPLEX_COUNT], + ) { + for h in 0..COMPLEX_COUNT { + dftc[h].re += means[h].re; + dftc[h].im += means[h].im; + } + } + + #[inline] + // Applies a generalized wiener filter + fn filter_coeffs(&self, dftc: &mut [Complex; COMPLEX_COUNT]) { + for h in 0..COMPLEX_COUNT { + let psd = dftc[h].re.mul_add(dftc[h].re, dftc[h].im.powi(2)); + let mult = ((psd - self.sigmas[h]) / (psd + 1e-15)).max(0.0); + dftc[h].re *= mult; + dftc[h].im *= mult; + } + } + + fn copy_pad(&self, src: &Frame, dest: &mut [Plane; 3]) { + for p in 0..src.planes.len() { + let src_width = src.planes[p].cfg.width; + let dest_width = dest[p].cfg.width; + let src_height = src.planes[p].cfg.height; + let dest_height = dest[p].cfg.height; + let src_stride = src.planes[p].cfg.stride; + let dest_stride = dest[p].cfg.stride; + + let offy = (dest_height - src_height) / 2; + let offx = (dest_width - src_width) / 2; + + bitblt( + &mut dest[p].data_origin_mut()[(dest_stride * offy + offx)..], + dest_stride, + src.planes[p].data_origin(), + src_stride, + src_width, + src_height, + ); + + let mut dest_ptr = + &mut dest[p].data_origin_mut()[(dest_stride * offy)..]; + for _ in offy..(src_height + offy) { + let dest_slice = &mut dest_ptr[..dest_width]; + + let mut w = offx * 2; + for x in 0..offx { + dest_slice[x] = dest_slice[w]; + w -= 1; + } + + w = offx + src_width - 2; + for x in (offx + src_width)..dest_width { + dest_slice[x] = dest_slice[w]; + w -= 1; + } + + dest_ptr = &mut dest_ptr[dest_stride..]; + } + + let dest_origin = dest[p].data_origin_mut(); + let mut w = offy * 2; + for y in 0..offy { + // SAFETY: `copy_from_slice` has borrow checker issues here + // because we are copying from `dest` to `dest`, but we manually + // know that the two slices will not overlap. We still slice + // the start and end as a safety check. + unsafe { + copy_nonoverlapping( + dest_origin[(dest_stride * w)..][..dest_width].as_ptr(), + dest_origin[(dest_stride * y)..][..dest_width].as_mut_ptr(), + dest_width, + ); + } + w -= 1; + } + + w = offy + src_height - 2; + for y in (offy + src_height)..dest_height { + // SAFETY: `copy_from_slice` has borrow checker issues here + // because we are copying from `dest` to `dest`, but we manually + // know that the two slices will not overlap. We still slice + // the start and end as a safety check. 
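+      //
+      // This loop mirrors rows past the bottom edge of the image
+      // (reading back upward from `offy + src_height - 2`), matching
+      // the top-edge mirroring performed just above.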
+ unsafe { + copy_nonoverlapping( + dest_origin[(dest_stride * w)..][..dest_width].as_ptr(), + dest_origin[(dest_stride * y)..][..dest_width].as_mut_ptr(), + dest_width, + ); + } + w -= 1; + } + } + } + + fn cast( + &self, ebuff: &[f32], dest: &mut [T], dest_width: usize, + dest_height: usize, dest_stride: usize, ebp_stride: usize, + ) { + let ebuff = ebuff.chunks(ebp_stride); + let dest = dest.chunks_mut(dest_stride); + + for (ebuff, dest) in ebuff.zip(dest).take(dest_height) { + for x in 0..dest_width { + let fval = ebuff[x].mul_add(self.dest_scale, 0.5); + dest[x] = + clamp(T::cast_from(fval as u16), T::cast_from(0u16), self.peak); + } + } + } + + // Applies a real-to-complex 3-dimensional FFT to `real` + fn real_to_complex_3d( + &mut self, real: &[f32; BLOCK_VOLUME], + output: &mut [Complex; COMPLEX_COUNT], + ) { + let input = + ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), real).unwrap(); + let mut temp1 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut temp2 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut output = + ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), output) + .unwrap(); + + ndfft_r2c(&input, &mut temp1, &mut self.fft.0, 2); + ndfft(&temp1, &mut temp2, &mut self.fft.1, 1); + ndfft(&temp2, &mut output, &mut self.fft.2, 0); + } + + // Applies a complex-to-real 3-dimensional FFT to `complex` + fn complex_to_real_3d( + &mut self, complex: &[Complex; COMPLEX_COUNT], + output: &mut [f32; BLOCK_VOLUME], + ) { + let input = + ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), complex) + .unwrap(); + let mut temp0 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut temp1 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut output = + ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), output).unwrap(); + + ndifft(&input, &mut temp0, &mut self.fft.2, 0); + ndifft(&temp0, &mut temp1, &mut self.fft.1, 1); + ndifft_r2c(&temp1, &mut output, &mut self.fft.0, 2); + output.iter_mut().for_each(|d| { + *d *= BLOCK_VOLUME as f32; + }); + } +} + +#[inline(always)] +fn extra(a: usize, b: usize) -> usize { + if a % b > 0 { + b - (a % b) + } else { + 0 + } +} + +// Identical to Vapoursynth's implementation `vs_bitblt` +// which basically copies the pixels in a plane. +fn bitblt( + mut dest: &mut [T], dest_stride: usize, mut src: &[T], src_stride: usize, + width: usize, height: usize, +) { + if src_stride == dest_stride && src_stride == width { + dest[..(width * height)].copy_from_slice(&src[..(width * height)]); + } else { + for _ in 0..height { + dest[..width].copy_from_slice(&src[..width]); + src = &src[src_stride..]; + dest = &mut dest[dest_stride..]; + } + } +} diff --git a/src/fuzzing.rs b/src/fuzzing.rs index aab9abe059..d440f3f88f 100644 --- a/src/fuzzing.rs +++ b/src/fuzzing.rs @@ -257,6 +257,7 @@ impl Arbitrary for ArbitraryEncoder { switch_frame_interval: u.int_in_range(0..=3)?, tune: *u.choose(&[Tune::Psnr, Tune::Psychovisual])?, film_grain_params: None, + denoise_strength: u.int_in_range(0..=50)?, }; let frame_count = diff --git a/src/lib.rs b/src/lib.rs index 3425588db4..05e2270115 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -257,6 +257,7 @@ mod cdef; #[doc(hidden)] pub mod context; mod deblock; +mod denoise; mod encoder; mod entropymode; mod lrf; diff --git a/src/test_encode_decode/mod.rs b/src/test_encode_decode/mod.rs index 9e6082ed55..8bdecca838 100644 --- a/src/test_encode_decode/mod.rs +++ b/src/test_encode_decode/mod.rs @@ -10,9 +10,8 @@ // Fuzzing only uses a subset of these. 
#![cfg_attr(fuzzing, allow(unused))] -use crate::color::ChromaSampling; - use crate::api::config::GrainTableSegment; +use crate::color::ChromaSampling; use crate::util::Pixel; use crate::*; diff --git a/src/util/align.rs b/src/util/align.rs index c86424e8b2..02928698cd 100644 --- a/src/util/align.rs +++ b/src/util/align.rs @@ -42,6 +42,20 @@ impl Aligned { } } +impl std::ops::Deref for Aligned { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl std::ops::DerefMut for Aligned { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + /// An analog to a Box<[T]> where the underlying slice is aligned. /// Alignment is according to the architecture-specific SIMD constraints. pub struct AlignedBoxedSlice { From e804929b48f619a29b172cde0a1838e49f5cbb0c Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Wed, 7 Sep 2022 01:30:55 -0400 Subject: [PATCH 02/11] Add benchmarks --- benches/bench.rs | 27 +++++++++++++++- benches/denoise.rs | 76 ++++++++++++++++++++++++++++++++++++++++++++++ benches/dist.rs | 21 +------------ benches/mc.rs | 23 ++------------ src/denoise.rs | 2 +- src/lib.rs | 6 +++- 6 files changed, 111 insertions(+), 44 deletions(-) create mode 100644 benches/denoise.rs diff --git a/benches/bench.rs b/benches/bench.rs index 4929d83fdc..e5395b99de 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -7,6 +7,7 @@ // Media Patent License 1.0 was not distributed with this source code in the // PATENTS file, you can obtain it at www.aomedia.org/license/patent. +mod denoise; mod dist; mod mc; mod plane; @@ -23,12 +24,15 @@ use rav1e::bench::partition::*; use rav1e::bench::predict::*; use rav1e::bench::rdo::*; use rav1e::bench::transform::*; +use rav1e::prelude::*; use crate::plane::plane; use crate::rdo::rdo; use crate::transform::{forward_transforms, inverse_transforms}; use criterion::*; +use rand::Rng; +use rand_chacha::ChaChaRng; use std::sync::Arc; use std::time::Duration; @@ -193,6 +197,26 @@ fn update_cdf_4(b: &mut Bencher) { }); } +fn fill_plane(ra: &mut ChaChaRng, plane: &mut Plane) { + let stride = plane.cfg.stride; + for row in plane.data_origin_mut().chunks_mut(stride) { + for pixel in row { + let v: u8 = ra.gen(); + *pixel = T::cast_from(v); + } + } +} + +fn new_plane( + ra: &mut ChaChaRng, width: usize, height: usize, +) -> Plane { + let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8); + + fill_plane(ra, &mut p); + + p +} + criterion_group!(intra_prediction, predict::pred_bench,); criterion_group!(cfl, cfl_rdo); @@ -217,5 +241,6 @@ criterion_main!( ec, rdo, plane, - mc::mc + mc::mc, + denoise::denoise ); diff --git a/benches/denoise.rs b/benches/denoise.rs new file mode 100644 index 0000000000..5daad7577b --- /dev/null +++ b/benches/denoise.rs @@ -0,0 +1,76 @@ +use super::new_plane; +use criterion::*; +use rand::SeedableRng; +use rand_chacha::ChaChaRng; +use rav1e::bench::denoise::*; +use rav1e::prelude::*; +use std::collections::BTreeMap; +use std::sync::Arc; + +fn bench_dft_denoiser_8b(c: &mut Criterion) { + let mut ra = ChaChaRng::from_seed([0; 32]); + let w = 640; + let h = 480; + let mut frame_queue = BTreeMap::new(); + for i in 0..3 { + frame_queue.insert( + i, + Some(Arc::new(Frame { + planes: [ + new_plane::(&mut ra, w, h), + new_plane::(&mut ra, w / 2, h / 2), + new_plane::(&mut ra, w / 2, h / 2), + ], + })), + ); + } + frame_queue.insert(3, None); + + c.bench_function("dft_denoiser_8b", |b| { + b.iter_with_setup( + || DftDenoiser::new(2.0, w, h, 8, ChromaSampling::Cs420), + |mut denoiser| { + for _ in 0..3 { + let _ = 
black_box(denoiser.filter_frame(&frame_queue)); + } + }, + ) + }); +} + +fn bench_dft_denoiser_10b(c: &mut Criterion) { + let mut ra = ChaChaRng::from_seed([0; 32]); + let w = 640; + let h = 480; + let mut frame_queue = BTreeMap::new(); + for i in 0..3 { + let mut frame = Frame { + planes: [ + new_plane::(&mut ra, w, h), + new_plane::(&mut ra, w / 2, h / 2), + new_plane::(&mut ra, w / 2, h / 2), + ], + }; + for p in 0..3 { + // Shift from 16-bit to 10-bit + frame.planes[p].data.iter_mut().for_each(|pix| { + *pix = *pix >> 6; + }); + } + frame_queue.insert(i, Some(Arc::new(frame))); + } + frame_queue.insert(3, None); + + c.bench_function("dft_denoiser_10b", |b| { + b.iter_with_setup( + || DftDenoiser::new(2.0, w, h, 10, ChromaSampling::Cs420), + |mut denoiser| { + for _ in 0..3 { + let _ = black_box(denoiser.filter_frame(&frame_queue)); + } + }, + ) + }); +} + +criterion_group!(denoise, bench_dft_denoiser_8b, bench_dft_denoiser_10b); diff --git a/benches/dist.rs b/benches/dist.rs index b9cffd3adb..6c1415310d 100644 --- a/benches/dist.rs +++ b/benches/dist.rs @@ -9,6 +9,7 @@ #![allow(clippy::trivially_copy_pass_by_ref)] +use super::new_plane; use criterion::*; use rand::{Rng, SeedableRng}; use rand_chacha::ChaChaRng; @@ -69,26 +70,6 @@ const DIST_BENCH_SET: &[(BlockSize, usize)] = &[ (BLOCK_64X16, 10), ]; -fn fill_plane(ra: &mut ChaChaRng, plane: &mut Plane) { - let stride = plane.cfg.stride; - for row in plane.data_origin_mut().chunks_mut(stride) { - for pixel in row { - let v: u8 = ra.gen(); - *pixel = T::cast_from(v); - } - } -} - -fn new_plane( - ra: &mut ChaChaRng, width: usize, height: usize, -) -> Plane { - let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8); - - fill_plane(ra, &mut p); - - p -} - type DistFn = fn( plane_org: &PlaneRegion<'_, T>, plane_ref: &PlaneRegion<'_, T>, diff --git a/benches/mc.rs b/benches/mc.rs index 4d3f81add7..8b6156f89b 100644 --- a/benches/mc.rs +++ b/benches/mc.rs @@ -1,7 +1,8 @@ #![allow(clippy::unit_arg)] +use super::new_plane; use criterion::*; -use rand::{Rng, SeedableRng}; +use rand::SeedableRng; use rand_chacha::ChaChaRng; use rav1e::bench::cpu_features::*; use rav1e::bench::frame::{AsRegion, PlaneOffset, PlaneSlice}; @@ -525,26 +526,6 @@ criterion_group!( bench_prep_8tap_center_hbd ); -fn fill_plane(ra: &mut ChaChaRng, plane: &mut Plane) { - let stride = plane.cfg.stride; - for row in plane.data_origin_mut().chunks_mut(stride) { - for pixel in row { - let v: u8 = ra.gen(); - *pixel = T::cast_from(v); - } - } -} - -fn new_plane( - ra: &mut ChaChaRng, width: usize, height: usize, -) -> Plane { - let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8); - - fill_plane(ra, &mut p); - - p -} - fn get_params( rec_plane: &Plane, po: PlaneOffset, mv: MotionVector, ) -> (i32, i32, PlaneSlice) { diff --git a/src/denoise.rs b/src/denoise.rs index 93cfb499ae..30b8ad5995 100644 --- a/src/denoise.rs +++ b/src/denoise.rs @@ -30,7 +30,7 @@ const INC: usize = SB_SIZE - SO_SIZE; /// This denoiser is based on the DFTTest plugin from Vapoursynth. /// This type of denoising was chosen because it provides /// high quality while not being too slow. 
-pub(crate) struct DftDenoiser +pub struct DftDenoiser where T: Pixel, { diff --git a/src/lib.rs b/src/lib.rs index 05e2270115..0c31e91123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -257,7 +257,8 @@ mod cdef; #[doc(hidden)] pub mod context; mod deblock; -mod denoise; +#[doc(hidden)] +pub mod denoise; mod encoder; mod entropymode; mod lrf; @@ -447,6 +448,9 @@ pub mod bench { pub mod context { pub use crate::context::*; } + pub mod denoise { + pub use crate::denoise::*; + } pub mod dist { pub use crate::dist::*; } From b0f42edbc6af2a46b8d7ba8df87adc60f8874466 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Wed, 7 Sep 2022 20:12:03 -0400 Subject: [PATCH 03/11] Use stack arrays, 5% perf improvement --- src/denoise.rs | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/denoise.rs b/src/denoise.rs index 30b8ad5995..24ca2eeb05 100644 --- a/src/denoise.rs +++ b/src/denoise.rs @@ -2,7 +2,7 @@ use crate::api::FrameQueue; use crate::util::Aligned; use crate::EncoderStatus; use arrayvec::ArrayVec; -use ndarray::{Array3, ArrayView3, ArrayViewMut3}; +use ndarray::{ArrayView3, ArrayViewMut3}; use ndrustfft::{ ndfft, ndfft_r2c, ndifft, ndifft_r2c, Complex, FftHandler, R2cFftHandler, }; @@ -525,10 +525,20 @@ where &mut self, real: &[f32; BLOCK_VOLUME], output: &mut [Complex; COMPLEX_COUNT], ) { + let mut temp1_data = [Complex::default(); COMPLEX_COUNT]; + let mut temp2_data = [Complex::default(); COMPLEX_COUNT]; let input = ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), real).unwrap(); - let mut temp1 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); - let mut temp2 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut temp1 = ArrayViewMut3::from_shape( + (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), + &mut temp1_data, + ) + .unwrap(); + let mut temp2 = ArrayViewMut3::from_shape( + (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), + &mut temp2_data, + ) + .unwrap(); let mut output = ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), output) .unwrap(); @@ -543,11 +553,21 @@ where &mut self, complex: &[Complex; COMPLEX_COUNT], output: &mut [f32; BLOCK_VOLUME], ) { + let mut temp0_data = [Complex::default(); COMPLEX_COUNT]; + let mut temp1_data = [Complex::default(); COMPLEX_COUNT]; let input = ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), complex) .unwrap(); - let mut temp0 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); - let mut temp1 = Array3::zeros((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1)); + let mut temp0 = ArrayViewMut3::from_shape( + (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), + &mut temp0_data, + ) + .unwrap(); + let mut temp1 = ArrayViewMut3::from_shape( + (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), + &mut temp1_data, + ) + .unwrap(); let mut output = ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), output).unwrap(); From f29265f40dd69bdcca574372c727e9222a0b54ad Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Fri, 28 Oct 2022 21:47:31 -0400 Subject: [PATCH 04/11] WIP --- Cargo.toml | 1 + src/api/internal.rs | 4 +- src/denoise.rs | 677 ++++++++++---------------------------------- 3 files changed, 153 insertions(+), 529 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d0ffbc579d..5329a4bf0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -114,6 +114,7 @@ nom = "7.0.0" ndarray = "0.15.4" # Used for running FFTs during denoising ndrustfft = "0.3.0" +wide = "0.7.5" [dependencies.image] version = "0.24.3" diff --git a/src/api/internal.rs b/src/api/internal.rs index 7acb9a96ed..b4cd728421 100644 --- a/src/api/internal.rs +++ 
b/src/api/internal.rs @@ -15,7 +15,7 @@ use crate::api::{ }; use crate::color::ChromaSampling::Cs400; use crate::cpu_features::CpuFeatureLevel; -use crate::denoise::{DftDenoiser, TB_MIDPOINT}; +use crate::denoise::{DftDenoiser, TEMPORAL_RADIUS}; use crate::dist::get_satd; use crate::encoder::*; use crate::frame::*; @@ -378,7 +378,7 @@ impl ContextInner { loop { let denoiser_frame = denoiser.cur_frameno; if (!is_flushing - && input_frameno >= denoiser_frame + TB_MIDPOINT as u64) + && input_frameno >= denoiser_frame + TEMPORAL_RADIUS as u64) || (is_flushing && Some(denoiser_frame) < self.limit) { self.frame_q.insert( diff --git a/src/denoise.rs b/src/denoise.rs index 24ca2eeb05..43854671b4 100644 --- a/src/denoise.rs +++ b/src/denoise.rs @@ -2,59 +2,38 @@ use crate::api::FrameQueue; use crate::util::Aligned; use crate::EncoderStatus; use arrayvec::ArrayVec; -use ndarray::{ArrayView3, ArrayViewMut3}; -use ndrustfft::{ - ndfft, ndfft_r2c, ndifft, ndifft_r2c, Complex, FftHandler, R2cFftHandler, -}; -use std::collections::{BTreeMap, VecDeque}; -use std::f64::consts::PI; +use std::f32::consts::PI; use std::iter::once; -use std::mem::size_of; -use std::ptr::copy_nonoverlapping; +use std::mem::{size_of, transmute, MaybeUninit}; use std::sync::Arc; use v_frame::frame::Frame; -use v_frame::math::clamp; -use v_frame::pixel::{CastFromPrimitive, ChromaSampling, Pixel}; +use v_frame::pixel::{ChromaSampling, Pixel}; use v_frame::plane::Plane; -const SB_SIZE: usize = 16; -const SO_SIZE: usize = 12; -const TB_SIZE: usize = 3; -pub(crate) const TB_MIDPOINT: usize = TB_SIZE / 2; -const BLOCK_AREA: usize = SB_SIZE * SB_SIZE; -const BLOCK_VOLUME: usize = BLOCK_AREA * TB_SIZE; -const COMPLEX_COUNT: usize = (SB_SIZE / 2 + 1) * SB_SIZE * TB_SIZE; -const CCNT2: usize = COMPLEX_COUNT * 2; -const INC: usize = SB_SIZE - SO_SIZE; +pub const TEMPORAL_RADIUS: usize = 1; +const TEMPORAL_SIZE: usize = TEMPORAL_RADIUS * 2 + 1; +const BLOCK_SIZE: usize = 16; +const BLOCK_STEP: usize = 12; +const REAL_TOTAL: usize = TEMPORAL_SIZE * BLOCK_SIZE * BLOCK_SIZE; +const COMPLEX_TOTAL: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1); -/// This denoiser is based on the DFTTest plugin from Vapoursynth. +/// This denoiser is based on the DFTTest2 plugin from Vapoursynth. /// This type of denoising was chosen because it provides -/// high quality while not being too slow. +/// high quality while not being too slow. The DFTTest2 implementation +/// is much faster than the original due to its custom FFT kernel. pub struct DftDenoiser where T: Pixel, { - chroma_sampling: ChromaSampling, - dest_scale: f32, - src_scale: f32, - peak: T, - - // These indices refer to planes of the input - pad_dimensions: ArrayVec<(usize, usize), 3>, - effective_heights: ArrayVec, - - hw: Aligned<[f32; BLOCK_VOLUME]>, - dftgc: Aligned<[Complex; COMPLEX_COUNT]>, - fft: (R2cFftHandler, FftHandler, FftHandler), - sigmas: Aligned<[f32; CCNT2]>, - - // This stores a copy of the unfiltered previous frame, - // since in `frame_q` it will be filtered already. - // We only have one frame, but it's left as a Vec so that - // TB_SIZE could potentially be tweaked without any - // code changes. 
- frame_buffer: VecDeque>>, + // External values + prev_frame: Option>>, pub(crate) cur_frameno: u64, + // Local values + sigma: Aligned<[f32; COMPLEX_TOTAL]>, + window: Aligned<[f32; REAL_TOTAL]>, + window_freq: Aligned<[f32; REAL_TOTAL]>, + pmin: f32, + pmax: f32, } impl DftDenoiser @@ -72,521 +51,161 @@ where assert!(bit_depth > 8); } - let dest_scale = (1 << (bit_depth - 8)) as f32; - let src_scale = 1.0 / dest_scale; - let peak = T::cast_from((1u16 << bit_depth) - 1); - - let mut pad_dimensions = ArrayVec::<_, 3>::new(); - let mut effective_heights = ArrayVec::<_, 3>::new(); - for plane in 0..3 { - let ae = (SB_SIZE - SO_SIZE).max(SO_SIZE) * 2; - let (width, height) = if plane == 0 { - (width, height) - } else { - chroma_sampling.get_chroma_dimensions(width, height) - }; - let pad_w = width + extra(width, SB_SIZE) + ae; - let pad_h = height + extra(height, SB_SIZE) + ae; - let e_h = - ((pad_h - SO_SIZE) / (SB_SIZE - SO_SIZE)) * (SB_SIZE - SO_SIZE); - pad_dimensions.push((pad_w, pad_h)); - effective_heights.push(e_h); - } - - let hw = Aligned::new(Self::create_window()); - let mut dftgr = Aligned::new([0f32; BLOCK_VOLUME]); - - let fft = ( - R2cFftHandler::new(SB_SIZE), - FftHandler::new(SB_SIZE), - FftHandler::new(TB_SIZE), - ); - - let mut wscale = 0.0f32; - for k in 0..BLOCK_VOLUME { - dftgr[k] = 255.0 * hw[k]; - wscale += hw[k].powi(2); - } - let wscale = 1.0 / wscale; - - let mut sigmas = Aligned::new([0f32; CCNT2]); - sigmas.fill(sigma / wscale); - - let mut denoiser = DftDenoiser { - chroma_sampling, - dest_scale, - src_scale, - peak, - pad_dimensions, - effective_heights, - hw, - fft, - sigmas, - dftgc: Aligned::new([Complex::default(); COMPLEX_COUNT]), - frame_buffer: VecDeque::with_capacity(TB_MIDPOINT), + let window = build_window(); + let wscale = window.iter().copied().map(|w| w * w).sum::(); + let sigma = sigma as f32 * wscale; + let pmin = 0.0f32; + let pmax = 500.0f32 * wscale; + // SAFETY: The `assume_init` is safe because the type we are claiming to have + // initialized here is a bunch of `MaybeUninit`s, which do not require initialization. + let mut window_freq_temp: Aligned<[MaybeUninit; REAL_TOTAL]> = + Aligned::new(unsafe { MaybeUninit::uninit().assume_init() }); + window_freq_temp.iter_mut().zip(window.iter()).for_each(|(freq, w)| { + freq.write(*w * 255.0); + }); + // SAFETY: Everything is initialized. Transmute the array to the + // initialized type. + let window_freq_temp: Aligned<[f32; REAL_TOTAL]> = + unsafe { transmute(window_freq_temp) }; + let window_freq = rdft(&window_freq_temp); + let sigma = Aligned::new([sigma; COMPLEX_TOTAL]); + + Self { + prev_frame: None, cur_frameno: 0, - }; - - let mut dftgc = Aligned::new([Complex::default(); COMPLEX_COUNT]); - denoiser.real_to_complex_3d(&dftgr, &mut dftgc); - denoiser.dftgc = dftgc; - - denoiser + sigma, + window, + window_freq, + pmin, + pmax, + } } pub fn filter_frame( &mut self, frame_q: &FrameQueue, ) -> Result, EncoderStatus> { - if self.frame_buffer.len() < TB_MIDPOINT.min(self.cur_frameno as usize) { - // We need to have the previous unfiltered frame - // in the buffer for temporal filtering. - return Err(EncoderStatus::NeedMoreData); - } - let future_frames = frame_q - .range((self.cur_frameno + 1)..) 
- .take(TB_MIDPOINT) - .map(|(_, f)| f) - .collect::>(); - if future_frames.len() != TB_MIDPOINT - && !future_frames.iter().any(|f| f.is_none()) - { + let next_frame = frame_q.get(&(self.cur_frameno + 1)); + if next_frame.is_none() { // We also need to have the next unfiltered frame, // unless we are at the end of the video. return Err(EncoderStatus::NeedMoreData); } let orig_frame = frame_q.get(&self.cur_frameno).unwrap().as_ref().unwrap(); - let frames = self - .frame_buffer - .iter() - .cloned() - .enumerate() - .chain(once(((TB_MIDPOINT), Arc::clone(orig_frame)))) - .chain( - future_frames - .into_iter() - .flatten() - .cloned() - .enumerate() - .map(|(i, f)| (i + 1 + TB_MIDPOINT, f)), - ) - .collect::>(); - - let mut dest = (**orig_frame).clone(); - let mut pad = ArrayVec::<_, TB_SIZE>::new(); - for i in 0..TB_SIZE { - let dec = self.chroma_sampling.get_decimation().unwrap_or((0, 0)); - let mut pad_frame = [ - Plane::new( - self.pad_dimensions[0].0, - self.pad_dimensions[0].1, - 0, - 0, - 0, - 0, - ), - Plane::new( - self.pad_dimensions[1].0, - self.pad_dimensions[1].1, - dec.0, - dec.1, - 0, - 0, - ), - Plane::new( - self.pad_dimensions[2].0, - self.pad_dimensions[2].1, - dec.0, - dec.1, - 0, - 0, - ), - ]; - - let frame = frames.get(&i).unwrap_or(&frames[&TB_MIDPOINT]); - self.copy_pad(frame, &mut pad_frame); - pad.push(pad_frame); - } - self.do_filtering(&pad, &mut dest); - - if self.frame_buffer.len() == TB_MIDPOINT { - self.frame_buffer.pop_front(); - } - self.frame_buffer.push_back(Arc::clone(orig_frame)); + let frames = once(self.prev_frame.clone()) + .chain(once(Some(Arc::clone(orig_frame)))) + .chain(next_frame.cloned()) + .collect::, 3>>(); + + todo!(); + // let mut dest = (**orig_frame).clone(); + // let mut pad = ArrayVec::<_, TB_SIZE>::new(); + // for i in 0..TB_SIZE { + // let dec = self.chroma_sampling.get_decimation().unwrap_or((0, 0)); + // let mut pad_frame = [ + // Plane::new( + // self.pad_dimensions[0].0, + // self.pad_dimensions[0].1, + // 0, + // 0, + // 0, + // 0, + // ), + // Plane::new( + // self.pad_dimensions[1].0, + // self.pad_dimensions[1].1, + // dec.0, + // dec.1, + // 0, + // 0, + // ), + // Plane::new( + // self.pad_dimensions[2].0, + // self.pad_dimensions[2].1, + // dec.0, + // dec.1, + // 0, + // 0, + // ), + // ]; + + // let frame = frames.get(&i).unwrap_or(&frames[&TEMP_RADIUS]); + // self.copy_pad(frame, &mut pad_frame); + // pad.push(pad_frame); + // } + // self.do_filtering(&pad, &mut dest); + + self.prev_frame = Some(Arc::clone(orig_frame)); self.cur_frameno += 1; - Ok(dest) + // Ok(dest) } fn do_filtering(&mut self, src: &[[Plane; 3]], dest: &mut Frame) { - let mut dftr = [0f32; BLOCK_VOLUME]; - let mut dftc = [Complex::::default(); COMPLEX_COUNT]; - let mut means = [Complex::::default(); COMPLEX_COUNT]; - - for p in 0..3 { - let (pad_width, pad_height) = self.pad_dimensions[p]; - let mut ebuff = vec![0f32; pad_width * pad_height]; - let effective_height = self.effective_heights[p]; - let src_stride = src[0][p].cfg.stride; - let ebuff_stride = pad_width; - - let mut src_planes = src - .iter() - .map(|f| f[p].data_origin()) - .collect::>(); - - // SAFETY: We know the size of the planes we're working on, - // so we can safely ensure we are not out of bounds. - // There are a fair number of unsafe function calls here - // which are unsafe for optimization purposes. - // All are safe as long as we do not pass out-of-bounds parameters. 
- unsafe { - for y in (0..effective_height).step_by(INC) { - for x in (0..=(pad_width - SB_SIZE)).step_by(INC) { - for z in 0..TB_SIZE { - self.proc0( - &src_planes[z][x..], - &self.hw[(BLOCK_AREA * z)..], - &mut dftr[(BLOCK_AREA * z)..], - src_stride, - SB_SIZE, - self.src_scale, - ); - } - - self.real_to_complex_3d(&dftr, &mut dftc); - self.remove_mean(&mut dftc, &self.dftgc, &mut means); - - self.filter_coeffs(&mut dftc); - - self.add_mean(&mut dftc, &means); - self.complex_to_real_3d(&dftc, &mut dftr); - - self.proc1( - &dftr[(TB_MIDPOINT * BLOCK_AREA)..], - &self.hw[(TB_MIDPOINT * BLOCK_AREA)..], - &mut ebuff[(y * ebuff_stride + x)..], - SB_SIZE, - ebuff_stride, - ); - } - - for q in 0..TB_SIZE { - src_planes[q] = &src_planes[q][(INC * src_stride)..]; - } - } - } - - let dest_width = dest.planes[p].cfg.width; - let dest_height = dest.planes[p].cfg.height; - let dest_stride = dest.planes[p].cfg.stride; - let dest_plane = dest.planes[p].data_origin_mut(); - let ebp_offset = ebuff_stride * ((pad_height - dest_height) / 2) - + (pad_width - dest_width) / 2; - let ebp = &ebuff[ebp_offset..]; - - self.cast( - ebp, - dest_plane, - dest_width, - dest_height, - dest_stride, - ebuff_stride, - ); - } - } - - fn create_window() -> [f32; BLOCK_VOLUME] { - let mut hw = [0f32; BLOCK_VOLUME]; - let mut tw = [0f64; TB_SIZE]; - let mut sw = [0f64; SB_SIZE]; - - tw.fill_with(Self::temporal_window); - sw.iter_mut().enumerate().for_each(|(j, sw)| { - *sw = Self::spatial_window(j as f64 + 0.5); - }); - Self::normalize_for_overlap_add(&mut sw); - - let nscale = 1.0 / (BLOCK_VOLUME as f64).sqrt(); - for j in 0..TB_SIZE { - for k in 0..SB_SIZE { - for q in 0..SB_SIZE { - hw[(j * SB_SIZE + k) * SB_SIZE + q] = - (tw[j] * sw[k] * sw[q] * nscale) as f32; - } - } - } - - hw - } - - #[inline(always)] - // Hanning windowing - fn spatial_window(n: f64) -> f64 { - 0.5 - 0.5 * (2.0 * PI * n / SB_SIZE as f64).cos() - } - - #[inline(always)] - // Simple rectangular windowing - fn temporal_window() -> f64 { - 1.0 + todo!(); } +} - // Accounts for spatial block overlap - fn normalize_for_overlap_add(hw: &mut [f64]) { - let inc = SB_SIZE - SO_SIZE; - - let mut nw = [0f64; SB_SIZE]; - let hw = &mut hw[..SB_SIZE]; - - for q in 0..SB_SIZE { - for h in (0..=q).rev().step_by(inc) { - nw[q] += hw[h].powi(2); - } - for h in ((q + inc)..SB_SIZE).step_by(inc) { - nw[q] += hw[h].powi(2); - } - } - - for q in 0..SB_SIZE { - hw[q] /= nw[q].sqrt(); - } - } - - #[inline] - unsafe fn proc0( - &self, s0: &[T], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, - src_scale: f32, - ) { - let s0 = s0.as_ptr(); - let s1 = s1.as_ptr(); - let dest = dest.as_mut_ptr(); - - for u in 0..p1 { - for v in 0..p1 { - let s0 = s0.add(u * p0 + v); - let s1 = s1.add(u * p1 + v); - let dest = dest.add(u * p1 + v); - dest.write(u16::cast_from(s0.read()) as f32 * src_scale * s1.read()) - } - } - } +#[inline(always)] +// Hanning windowing +fn spatial_window_value(n: f32) -> f32 { + let temp = PI * n / BLOCK_SIZE as f32; + 0.5 * (1.0 - (2.0 * temp).cos()) +} - #[inline] - unsafe fn proc1( - &self, s0: &[f32], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, - ) { - let s0 = s0.as_ptr(); - let s1 = s1.as_ptr(); - let dest = dest.as_mut_ptr(); +#[inline(always)] +// Simple rectangular windowing +const fn temporal_window_value() -> f32 { + 1.0 +} - for u in 0..p0 { - for v in 0..p0 { - let s0 = s0.add(u * p0 + v); - let s1 = s1.add(u * p0 + v); - let dest = dest.add(u * p1 + v); - dest.write(s0.read().mul_add(s1.read(), dest.read())); +pub fn build_window() -> 
Aligned<[f32; REAL_TOTAL]> { + let temporal_window = [temporal_window_value(); TEMPORAL_SIZE]; + + // SAFETY: The `assume_init` is safe because the type we are claiming to have + // initialized here is a bunch of `MaybeUninit`s, which do not require initialization. + let mut spatial_window: [MaybeUninit; BLOCK_SIZE] = + unsafe { MaybeUninit::uninit().assume_init() }; + spatial_window.iter_mut().enumerate().for_each(|(i, val)| { + val.write(spatial_window_value(i as f32 + 0.5)); + }); + // SAFETY: Everything is initialized. Transmute the array to the + // initialized type. + let spatial_window: [f32; BLOCK_SIZE] = unsafe { transmute(spatial_window) }; + let spatial_window = normalize(&spatial_window); + + let mut window: Aligned<[MaybeUninit; REAL_TOTAL]> = + Aligned::new(unsafe { MaybeUninit::uninit().assume_init() }); + let mut window_iter = window.iter_mut(); + for t_val in temporal_window { + for s_val1 in spatial_window { + for s_val2 in spatial_window { + let mut value = t_val * s_val1 * s_val2; + // normalize for unnormalized FFT implementation + value /= (TEMPORAL_SIZE as f32).sqrt() * BLOCK_SIZE as f32; + window_iter.next().unwrap().write(value); } } } + // SAFETY: Everything is initialized. Transmute the array to the + // initialized type. + unsafe { transmute(window) } +} - #[inline] - fn remove_mean( - &self, dftc: &mut [Complex; COMPLEX_COUNT], - dftgc: &[Complex; COMPLEX_COUNT], - means: &mut [Complex; COMPLEX_COUNT], - ) { - let gf = dftc[0].re / dftgc[0].re; - - for h in 0..COMPLEX_COUNT { - means[h].re = gf * dftgc[h].re; - means[h].im = gf * dftgc[h].im; - dftc[h].re -= means[h].re; - dftc[h].im -= means[h].im; - } - } - - #[inline] - fn add_mean( - &self, dftc: &mut [Complex; COMPLEX_COUNT], - means: &[Complex; COMPLEX_COUNT], - ) { - for h in 0..COMPLEX_COUNT { - dftc[h].re += means[h].re; - dftc[h].im += means[h].im; - } - } - - #[inline] - // Applies a generalized wiener filter - fn filter_coeffs(&self, dftc: &mut [Complex; COMPLEX_COUNT]) { - for h in 0..COMPLEX_COUNT { - let psd = dftc[h].re.mul_add(dftc[h].re, dftc[h].im.powi(2)); - let mult = ((psd - self.sigmas[h]) / (psd + 1e-15)).max(0.0); - dftc[h].re *= mult; - dftc[h].im *= mult; - } - } - - fn copy_pad(&self, src: &Frame, dest: &mut [Plane; 3]) { - for p in 0..src.planes.len() { - let src_width = src.planes[p].cfg.width; - let dest_width = dest[p].cfg.width; - let src_height = src.planes[p].cfg.height; - let dest_height = dest[p].cfg.height; - let src_stride = src.planes[p].cfg.stride; - let dest_stride = dest[p].cfg.stride; - - let offy = (dest_height - src_height) / 2; - let offx = (dest_width - src_width) / 2; - - bitblt( - &mut dest[p].data_origin_mut()[(dest_stride * offy + offx)..], - dest_stride, - src.planes[p].data_origin(), - src_stride, - src_width, - src_height, - ); - - let mut dest_ptr = - &mut dest[p].data_origin_mut()[(dest_stride * offy)..]; - for _ in offy..(src_height + offy) { - let dest_slice = &mut dest_ptr[..dest_width]; - - let mut w = offx * 2; - for x in 0..offx { - dest_slice[x] = dest_slice[w]; - w -= 1; - } - - w = offx + src_width - 2; - for x in (offx + src_width)..dest_width { - dest_slice[x] = dest_slice[w]; - w -= 1; - } - - dest_ptr = &mut dest_ptr[dest_stride..]; - } - - let dest_origin = dest[p].data_origin_mut(); - let mut w = offy * 2; - for y in 0..offy { - // SAFETY: `copy_from_slice` has borrow checker issues here - // because we are copying from `dest` to `dest`, but we manually - // know that the two slices will not overlap. 
We still slice - // the start and end as a safety check. - unsafe { - copy_nonoverlapping( - dest_origin[(dest_stride * w)..][..dest_width].as_ptr(), - dest_origin[(dest_stride * y)..][..dest_width].as_mut_ptr(), - dest_width, - ); - } - w -= 1; - } - - w = offy + src_height - 2; - for y in (offy + src_height)..dest_height { - // SAFETY: `copy_from_slice` has borrow checker issues here - // because we are copying from `dest` to `dest`, but we manually - // know that the two slices will not overlap. We still slice - // the start and end as a safety check. - unsafe { - copy_nonoverlapping( - dest_origin[(dest_stride * w)..][..dest_width].as_ptr(), - dest_origin[(dest_stride * y)..][..dest_width].as_mut_ptr(), - dest_width, - ); - } - w -= 1; - } +pub fn normalize(window: &[f32; BLOCK_SIZE]) -> [f32; BLOCK_SIZE] { + let mut new_window = [0f32; BLOCK_SIZE]; + for q in 0..BLOCK_SIZE { + for h in (0..=q).rev().step_by(BLOCK_STEP) { + new_window[q] += window[h].powi(2); } - } - - fn cast( - &self, ebuff: &[f32], dest: &mut [T], dest_width: usize, - dest_height: usize, dest_stride: usize, ebp_stride: usize, - ) { - let ebuff = ebuff.chunks(ebp_stride); - let dest = dest.chunks_mut(dest_stride); - - for (ebuff, dest) in ebuff.zip(dest).take(dest_height) { - for x in 0..dest_width { - let fval = ebuff[x].mul_add(self.dest_scale, 0.5); - dest[x] = - clamp(T::cast_from(fval as u16), T::cast_from(0u16), self.peak); - } + for h in ((q + BLOCK_STEP)..BLOCK_SIZE).step_by(BLOCK_STEP) { + new_window[q] += window[h].powi(2); } } - - // Applies a real-to-complex 3-dimensional FFT to `real` - fn real_to_complex_3d( - &mut self, real: &[f32; BLOCK_VOLUME], - output: &mut [Complex; COMPLEX_COUNT], - ) { - let mut temp1_data = [Complex::default(); COMPLEX_COUNT]; - let mut temp2_data = [Complex::default(); COMPLEX_COUNT]; - let input = - ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), real).unwrap(); - let mut temp1 = ArrayViewMut3::from_shape( - (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), - &mut temp1_data, - ) - .unwrap(); - let mut temp2 = ArrayViewMut3::from_shape( - (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), - &mut temp2_data, - ) - .unwrap(); - let mut output = - ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), output) - .unwrap(); - - ndfft_r2c(&input, &mut temp1, &mut self.fft.0, 2); - ndfft(&temp1, &mut temp2, &mut self.fft.1, 1); - ndfft(&temp2, &mut output, &mut self.fft.2, 0); - } - - // Applies a complex-to-real 3-dimensional FFT to `complex` - fn complex_to_real_3d( - &mut self, complex: &[Complex; COMPLEX_COUNT], - output: &mut [f32; BLOCK_VOLUME], - ) { - let mut temp0_data = [Complex::default(); COMPLEX_COUNT]; - let mut temp1_data = [Complex::default(); COMPLEX_COUNT]; - let input = - ArrayView3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), complex) - .unwrap(); - let mut temp0 = ArrayViewMut3::from_shape( - (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), - &mut temp0_data, - ) - .unwrap(); - let mut temp1 = ArrayViewMut3::from_shape( - (TB_SIZE, SB_SIZE, SB_SIZE / 2 + 1), - &mut temp1_data, - ) - .unwrap(); - let mut output = - ArrayViewMut3::from_shape((TB_SIZE, SB_SIZE, SB_SIZE), output).unwrap(); - - ndifft(&input, &mut temp0, &mut self.fft.2, 0); - ndifft(&temp0, &mut temp1, &mut self.fft.1, 1); - ndifft_r2c(&temp1, &mut output, &mut self.fft.0, 2); - output.iter_mut().for_each(|d| { - *d *= BLOCK_VOLUME as f32; - }); - } -} - -#[inline(always)] -fn extra(a: usize, b: usize) -> usize { - if a % b > 0 { - b - (a % b) - } else { - 0 + for (w, nw) in window.iter().zip(new_window.iter_mut()) { + 
*nw = *w / nw.sqrt(); } + new_window } // Identical to Vapoursynth's implementation `vs_bitblt` @@ -605,3 +224,7 @@ fn bitblt( } } } + +fn rdft(data: &Aligned<[f32; REAL_TOTAL]>) -> Aligned<[f32; REAL_TOTAL]> { + todo!(); +} From dbc1682859c99eee92162705907b5327cb365e21 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Fri, 28 Oct 2022 23:27:06 -0400 Subject: [PATCH 05/11] Premature optimization --- Cargo.toml | 5 +- clippy.toml | 4 +- src/api/internal.rs | 8 +- src/asm/x86/transform/forward.rs | 14 --- src/denoise.rs | 176 +++++++++++++++++++++---------- src/util/array.rs | 13 +++ src/util/mod.rs | 2 + 7 files changed, 138 insertions(+), 84 deletions(-) create mode 100644 src/util/array.rs diff --git a/Cargo.toml b/Cargo.toml index 5329a4bf0d..2195127367 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,11 +110,8 @@ av1-grain = { version = "0.2.0", features = ["serialize"] } serde-big-array = { version = "0.4.1", optional = true } # Used for parsing film grain table files nom = "7.0.0" -# Used as a data holder during denoising -ndarray = "0.15.4" -# Used for running FFTs during denoising -ndrustfft = "0.3.0" wide = "0.7.5" +num-complex = "0.4.2" [dependencies.image] version = "0.24.3" diff --git a/clippy.toml b/clippy.toml index c75aafec2a..bb141ab8e0 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,5 +1,5 @@ too-many-arguments-threshold = 16 cognitive-complexity-threshold = 40 -trivial-copy-size-limit = 16 # 128-bits = 2 64-bit registers -doc-valid-idents = ["DFTTest"] # 128-bits = 2 64-bit registers +trivial-copy-size-limit = 16 # 128-bits = 2 64-bit registers +doc-valid-idents = ["DFTTest", "DFTTest2"] # 128-bits = 2 64-bit registers msrv = "1.60" diff --git a/src/api/internal.rs b/src/api/internal.rs index b4cd728421..3f6c93a74b 100644 --- a/src/api/internal.rs +++ b/src/api/internal.rs @@ -298,13 +298,7 @@ impl ContextInner { seq.clone(), ), denoiser: if enc.denoise_strength > 0 { - Some(DftDenoiser::new( - enc.denoise_strength as f32 / 10.0, - enc.width, - enc.height, - enc.bit_depth as u8, - enc.chroma_sampling, - )) + Some(DftDenoiser::new(enc.denoise_strength as f32 / 10.0)) } else { None }, diff --git a/src/asm/x86/transform/forward.rs b/src/asm/x86/transform/forward.rs index 18b1171517..7a46c310bb 100644 --- a/src/asm/x86/transform/forward.rs +++ b/src/asm/x86/transform/forward.rs @@ -316,20 +316,6 @@ impl SizeClass1D { } } -fn cast(x: &[T]) -> &[T; N] { - // SAFETY: we perform a bounds check with [..N], - // so casting to *const [T; N] is valid because the bounds - // check guarantees that x has N elements - unsafe { &*(&x[..N] as *const [T] as *const [T; N]) } -} - -fn cast_mut(x: &mut [T]) -> &mut [T; N] { - // SAFETY: we perform a bounds check with [..N], - // so casting to *mut [T; N] is valid because the bounds - // check guarantees that x has N elements - unsafe { &mut *(&mut x[..N] as *mut [T] as *mut [T; N]) } -} - #[allow(clippy::identity_op, clippy::erasing_op)] #[target_feature(enable = "avx2")] unsafe fn forward_transform_avx2( diff --git a/src/denoise.rs b/src/denoise.rs index 43854671b4..dbaebfb74d 100644 --- a/src/denoise.rs +++ b/src/denoise.rs @@ -1,21 +1,25 @@ use crate::api::FrameQueue; -use crate::util::Aligned; +use crate::util::{cast, Aligned}; use crate::EncoderStatus; use arrayvec::ArrayVec; +use num_complex::Complex64; +use num_traits::Zero; use std::f32::consts::PI; +use std::f64::consts::PI as PI64; use std::iter::once; -use std::mem::{size_of, transmute, MaybeUninit}; +use std::ptr::copy_nonoverlapping; use std::sync::Arc; use v_frame::frame::Frame; 
-use v_frame::pixel::{ChromaSampling, Pixel}; +use v_frame::pixel::Pixel; use v_frame::plane::Plane; +use wide::f32x8; pub const TEMPORAL_RADIUS: usize = 1; const TEMPORAL_SIZE: usize = TEMPORAL_RADIUS * 2 + 1; const BLOCK_SIZE: usize = 16; const BLOCK_STEP: usize = 12; -const REAL_TOTAL: usize = TEMPORAL_SIZE * BLOCK_SIZE * BLOCK_SIZE; -const COMPLEX_TOTAL: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1); +const REAL_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * BLOCK_SIZE; +const COMPLEX_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1); /// This denoiser is based on the DFTTest2 plugin from Vapoursynth. /// This type of denoising was chosen because it provides @@ -29,9 +33,9 @@ where prev_frame: Option>>, pub(crate) cur_frameno: u64, // Local values - sigma: Aligned<[f32; COMPLEX_TOTAL]>, - window: Aligned<[f32; REAL_TOTAL]>, - window_freq: Aligned<[f32; REAL_TOTAL]>, + sigma: Aligned<[f32; COMPLEX_SIZE]>, + window: Aligned<[f32; REAL_SIZE]>, + window_freq: Aligned<[Complex64; COMPLEX_SIZE]>, pmin: f32, pmax: f32, } @@ -41,34 +45,18 @@ where T: Pixel, { // This should only need to run once per video. - pub fn new( - sigma: f32, width: usize, height: usize, bit_depth: u8, - chroma_sampling: ChromaSampling, - ) -> Self { - if size_of::() == 1 { - assert!(bit_depth <= 8); - } else { - assert!(bit_depth > 8); - } - + pub fn new(sigma: f32) -> Self { let window = build_window(); let wscale = window.iter().copied().map(|w| w * w).sum::(); let sigma = sigma as f32 * wscale; let pmin = 0.0f32; let pmax = 500.0f32 * wscale; - // SAFETY: The `assume_init` is safe because the type we are claiming to have - // initialized here is a bunch of `MaybeUninit`s, which do not require initialization. - let mut window_freq_temp: Aligned<[MaybeUninit; REAL_TOTAL]> = - Aligned::new(unsafe { MaybeUninit::uninit().assume_init() }); - window_freq_temp.iter_mut().zip(window.iter()).for_each(|(freq, w)| { - freq.write(*w * 255.0); + let mut window_freq_real = Aligned::new([0f64; REAL_SIZE]); + window_freq_real.iter_mut().zip(window.iter()).for_each(|(freq, w)| { + *freq = *w as f64 * 255.0; }); - // SAFETY: Everything is initialized. Transmute the array to the - // initialized type. - let window_freq_temp: Aligned<[f32; REAL_TOTAL]> = - unsafe { transmute(window_freq_temp) }; - let window_freq = rdft(&window_freq_temp); - let sigma = Aligned::new([sigma; COMPLEX_TOTAL]); + let window_freq = rdft(&window_freq_real); + let sigma = Aligned::new([sigma; COMPLEX_SIZE]); Self { prev_frame: None, @@ -159,47 +147,52 @@ const fn temporal_window_value() -> f32 { 1.0 } -pub fn build_window() -> Aligned<[f32; REAL_TOTAL]> { +fn build_window() -> Aligned<[f32; REAL_SIZE]> { let temporal_window = [temporal_window_value(); TEMPORAL_SIZE]; - // SAFETY: The `assume_init` is safe because the type we are claiming to have - // initialized here is a bunch of `MaybeUninit`s, which do not require initialization. - let mut spatial_window: [MaybeUninit; BLOCK_SIZE] = - unsafe { MaybeUninit::uninit().assume_init() }; + let mut spatial_window = [0f32; BLOCK_SIZE]; spatial_window.iter_mut().enumerate().for_each(|(i, val)| { - val.write(spatial_window_value(i as f32 + 0.5)); + *val = spatial_window_value(i as f32 + 0.5); }); - // SAFETY: Everything is initialized. Transmute the array to the - // initialized type. 
- let spatial_window: [f32; BLOCK_SIZE] = unsafe { transmute(spatial_window) }; let spatial_window = normalize(&spatial_window); - let mut window: Aligned<[MaybeUninit; REAL_TOTAL]> = - Aligned::new(unsafe { MaybeUninit::uninit().assume_init() }); - let mut window_iter = window.iter_mut(); + let mut window = Aligned::new([0f32; REAL_SIZE]); + let mut i = 0; for t_val in temporal_window { for s_val1 in spatial_window { - for s_val2 in spatial_window { + for s_vals2 in spatial_window.chunks_exact(8) { + let s_val2 = f32x8::new(*cast::<8, _>(s_vals2)); let mut value = t_val * s_val1 * s_val2; // normalize for unnormalized FFT implementation - value /= (TEMPORAL_SIZE as f32).sqrt() * BLOCK_SIZE as f32; - window_iter.next().unwrap().write(value); + value /= + f32x8::from((TEMPORAL_SIZE as f32).sqrt() * BLOCK_SIZE as f32); + // SAFETY: We know the slices are valid sizes + unsafe { + copy_nonoverlapping( + value.as_array_ref().as_ptr(), + window.as_mut_ptr().add(i), + 8usize, + ) + }; + i += 8; } } } - // SAFETY: Everything is initialized. Transmute the array to the - // initialized type. - unsafe { transmute(window) } + window } -pub fn normalize(window: &[f32; BLOCK_SIZE]) -> [f32; BLOCK_SIZE] { +fn normalize(window: &[f32; BLOCK_SIZE]) -> [f32; BLOCK_SIZE] { let mut new_window = [0f32; BLOCK_SIZE]; - for q in 0..BLOCK_SIZE { - for h in (0..=q).rev().step_by(BLOCK_STEP) { - new_window[q] += window[h].powi(2); - } - for h in ((q + BLOCK_STEP)..BLOCK_SIZE).step_by(BLOCK_STEP) { - new_window[q] += window[h].powi(2); + // SAFETY: We know all of the sizes, so bound checks are not needed. + unsafe { + for q in 0..BLOCK_SIZE { + let nw = new_window.get_unchecked_mut(q); + for h in (0..=q).rev().step_by(BLOCK_STEP) { + *nw += window.get_unchecked(h).powi(2); + } + for h in ((q + BLOCK_STEP)..BLOCK_SIZE).step_by(BLOCK_STEP) { + *nw += window.get_unchecked(h).powi(2); + } } } for (w, nw) in window.iter().zip(new_window.iter_mut()) { @@ -225,6 +218,75 @@ fn bitblt( } } -fn rdft(data: &Aligned<[f32; REAL_TOTAL]>) -> Aligned<[f32; REAL_TOTAL]> { - todo!(); +fn rdft( + input: &Aligned<[f64; REAL_SIZE]>, +) -> Aligned<[Complex64; COMPLEX_SIZE]> { + const SHAPE: [usize; 3] = [TEMPORAL_SIZE, BLOCK_SIZE, BLOCK_SIZE]; + + let mut output = Aligned::new([Complex64::zero(); COMPLEX_SIZE]); + + for i in 0..(SHAPE[0] * SHAPE[1]) { + dft( + &mut output[(i * (SHAPE[1] / 2 + 1))..], + DftInput::Real(&input[(i * SHAPE[1])..]), + SHAPE[2], + 1, + ); + } + + let mut output2 = Aligned::new([Complex64::zero(); COMPLEX_SIZE]); + + let stride = SHAPE[2] / 2 + 1; + for i in 0..SHAPE[0] { + for j in 0..stride { + dft( + &mut output2[(i * SHAPE[1] * stride + j)..], + DftInput::Complex(&output[(i * SHAPE[1] * stride + j)..]), + SHAPE[1], + stride, + ); + } + } + + let stride = SHAPE[1] * stride; + for i in 0..stride { + dft(&mut output[i..], DftInput::Complex(&output2[i..]), SHAPE[0], stride); + } + + output +} + +enum DftInput<'a> { + Real(&'a [f64]), + Complex(&'a [Complex64]), +} + +#[inline(always)] +fn dft(output: &mut [Complex64], input: DftInput, n: usize, stride: usize) { + match input { + DftInput::Real(input) => { + let out_num = n / 2 + 1; + for i in 0..out_num { + let mut sum = Complex64::zero(); + for j in 0..n { + let imag = -2f64 * i as f64 * j as f64 * PI64 / n as f64; + let weight = Complex64::new(imag.cos(), imag.sin()); + sum += input[j * stride] * weight; + } + output[i * stride] = sum; + } + } + DftInput::Complex(input) => { + let out_num = n; + for i in 0..out_num { + let mut sum = Complex64::zero(); + for j 
in 0..n { + let imag = -2f64 * i as f64 * j as f64 * PI64 / n as f64; + let weight = Complex64::new(imag.cos(), imag.sin()); + sum += input[j * stride] * weight; + } + output[i * stride] = sum; + } + } + } } diff --git a/src/util/array.rs b/src/util/array.rs new file mode 100644 index 0000000000..afeba3f244 --- /dev/null +++ b/src/util/array.rs @@ -0,0 +1,13 @@ +pub fn cast(x: &[T]) -> &[T; N] { + // SAFETY: we perform a bounds check with [..N], + // so casting to *const [T; N] is valid because the bounds + // check guarantees that x has N elements + unsafe { &*(&x[..N] as *const [T] as *const [T; N]) } +} + +pub fn cast_mut(x: &mut [T]) -> &mut [T; N] { + // SAFETY: we perform a bounds check with [..N], + // so casting to *mut [T; N] is valid because the bounds + // check guarantees that x has N elements + unsafe { &mut *(&mut x[..N] as *mut [T] as *mut [T; N]) } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index e5af5c2461..c7c30aef45 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -8,6 +8,7 @@ // PATENTS file, you can obtain it at www.aomedia.org/license/patent. mod align; +mod array; #[macro_use] mod cdf; mod kmeans; @@ -18,6 +19,7 @@ pub use v_frame::math::*; pub use v_frame::pixel::*; pub use align::*; +pub(crate) use array::*; pub use uninit::*; pub use kmeans::*; From 85d98bd0a20581ad51b1a23985e21e4e528bf3a2 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Sat, 29 Oct 2022 20:42:19 -0400 Subject: [PATCH 06/11] More changes --- Cargo.toml | 2 +- src/api/internal.rs | 7 +- src/denoise/kernel.rs | 138 +++++++++++++++++++++ src/{denoise.rs => denoise/mod.rs} | 191 +++++++++++++++++++---------- 4 files changed, 268 insertions(+), 70 deletions(-) create mode 100644 src/denoise/kernel.rs rename src/{denoise.rs => denoise/mod.rs} (59%) diff --git a/Cargo.toml b/Cargo.toml index 2195127367..664cd428ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,7 +110,7 @@ av1-grain = { version = "0.2.0", features = ["serialize"] } serde-big-array = { version = "0.4.1", optional = true } # Used for parsing film grain table files nom = "7.0.0" -wide = "0.7.5" +wide = { git = "https://github.com/shssoichiro/wide", branch = "fast-cast-fns" } num-complex = "0.4.2" [dependencies.image] diff --git a/src/api/internal.rs b/src/api/internal.rs index 3f6c93a74b..bac4910280 100644 --- a/src/api/internal.rs +++ b/src/api/internal.rs @@ -298,7 +298,12 @@ impl ContextInner { seq.clone(), ), denoiser: if enc.denoise_strength > 0 { - Some(DftDenoiser::new(enc.denoise_strength as f32 / 10.0)) + Some(DftDenoiser::::new( + enc.denoise_strength as f32 / 10.0, + enc.width, + enc.height, + enc.bit_depth, + )) } else { None }, diff --git a/src/denoise/kernel.rs b/src/denoise/kernel.rs new file mode 100644 index 0000000000..a283939b4a --- /dev/null +++ b/src/denoise/kernel.rs @@ -0,0 +1,138 @@ +use std::{ + mem::{size_of, transmute}, + ptr::copy_nonoverlapping, +}; + +use crate::util::cast; + +use super::{ + bitblt, calc_pad_size, f32x16, BLOCK_SIZE, BLOCK_STEP, TEMPORAL_SIZE, +}; +use av_metrics::video::Pixel; +use num_complex::Complex32; +use wide::{f32x8, u16x8, u8x16}; + +pub fn reflection_padding( + output: &mut [T], input: &[T], width: usize, height: usize, stride: usize, +) { + let pad_width = calc_pad_size(width); + let pad_height = calc_pad_size(height); + + let offset_y = (pad_height - height) / 2; + let offset_x = (pad_width - width) / 2; + + bitblt( + &mut output[(offset_y * pad_width + offset_x)..], + pad_width, + input, + stride, + width, + height, + ); + + // copy left and right regions + for y in 
offset_y..(offset_y + height) { + let dst_line = &mut output[(y * pad_width)..]; + + for x in 0..offset_x { + dst_line[x] = dst_line[offset_x * 2 - x]; + } + + for x in (offset_x + width)..pad_width { + dst_line[x] = dst_line[2 * (offset_x + width) - 2 - x]; + } + } + + // copy top region + for y in 0..offset_y { + let dst = output[(y * pad_width)..][..pad_width].as_mut_ptr(); + let src = output[((offset_y * 2 - y) * pad_width)..][..pad_width].as_ptr(); + // SAFETY: We check the start and end bounds above. + // We have to use `copy_nonoverlapping` because the borrow checker is not happy with + // `copy_from_slice`. + unsafe { + copy_nonoverlapping(src, dst, pad_width); + } + } + + // copy bottom region + for y in (offset_y + height)..pad_height { + let dst = output[(y * pad_width)..][..pad_width].as_mut_ptr(); + let src = output[((2 * (offset_y + height) - 2 - y) * pad_width)..] + [..pad_width] + .as_ptr(); + // SAFETY: We check the start and end bounds above. + // We have to use `copy_nonoverlapping` because the borrow checker is not happy with + // `copy_from_slice`. + unsafe { + copy_nonoverlapping(src, dst, pad_width); + } + } +} + +pub fn load_block( + block: &mut [f32x16], shifted_src: &[T], width: usize, height: usize, + bit_depth: usize, window: &[f32], +) { + let scale = 1.0f32 / (1 << (bit_depth - 8)) as f32; + let offset_x = calc_pad_size(width); + let offset_y = calc_pad_size(height); + for i in 0..TEMPORAL_SIZE { + for j in 0..BLOCK_SIZE { + // The compiler will optimize away these branches + let vec_input = if size_of::() == 1 { + // SAFETY: We know that T is u8 + let u8s: [u8; 16] = unsafe { + *transmute::<_, &[u8; 16]>(cast::<16, _>( + &shifted_src[((i * offset_y + j) * offset_x)..][..16], + )) + }; + let f32_upper = u8x16::new(u8s).to_u32x8().to_f32x8(); + let f32_lower = u8x16::new([ + u8s[8], u8s[9], u8s[10], u8s[11], u8s[12], u8s[13], u8s[14], + u8s[15], 0, 0, 0, 0, 0, 0, 0, 0, + ]) + .to_u32x8() + .to_f32x8(); + [f32_upper, f32_lower] + } else { + // SAFETY: We know that T is u8 + let u16s: [u16; 16] = unsafe { + *transmute::<_, &[u16; 16]>(cast::<16, _>( + &shifted_src[((i * offset_y + j) * offset_x)..][..16], + )) + }; + let f32_upper = u16x8::new(*cast(&u16s[..8])).to_u32x8().to_f32x8(); + let f32_lower = u16x8::new(*cast(&u16s[8..])).to_u32x8().to_f32x8(); + [f32_upper, f32_lower] + }; + let window = &window[(i * BLOCK_SIZE + j)..]; + let window_upper = f32x8::new(*cast(&window[..8])); + let window_lower = f32x8::new(*cast(&window[8..16])); + let result_upper = scale * window_upper * vec_input[0]; + let result_lower = scale * window_lower * vec_input[1]; + block[i * BLOCK_SIZE * 2 + j] = [result_upper, result_lower]; + } + } +} + +pub fn fused( + output: &mut [f32x16], sigma: f32, pmin: f32, pmax: f32, + window_freq: &[Complex32], +) { + todo!() +} + +pub fn store_block( + output: &mut [f32], input: &[f32x16], width: usize, height: usize, + window: &[f32], +) { + todo!() +} + +pub fn store_frame( + output: &mut [T], shifted_src: &[f32], width: usize, height: usize, + dst_stride: usize, src_stride: usize, +) { + todo!() +} diff --git a/src/denoise.rs b/src/denoise/mod.rs similarity index 59% rename from src/denoise.rs rename to src/denoise/mod.rs index dbaebfb74d..6815ffa78d 100644 --- a/src/denoise.rs +++ b/src/denoise/mod.rs @@ -1,17 +1,19 @@ +mod kernel; + use crate::api::FrameQueue; use crate::util::{cast, Aligned}; use crate::EncoderStatus; use arrayvec::ArrayVec; -use num_complex::Complex64; +use kernel::*; +use num_complex::Complex32; use num_traits::Zero; 
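
// Not part of the patch: a scalar model of the edge handling in
// `reflection_padding` from kernel.rs above. The padding mirrors without
// repeating the border sample (whole-sample reflection), matching the
// `dst_line[x] = dst_line[offset_x * 2 - x]` indexing there: for a row of
// length 5, index -2 maps to 2 and index 6 also maps to 2.
fn reflect(index: isize, len: isize) -> usize {
  debug_assert!(len > 1);
  let mut i = index;
  while i < 0 || i >= len {
    if i < 0 {
      // reflect about the first sample: -1 -> 1, -2 -> 2, ...
      i = -i;
    } else {
      // reflect about the last sample: len -> len - 2, ...
      i = 2 * (len - 1) - i;
    }
  }
  i as usize
}
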
use std::f32::consts::PI; -use std::f64::consts::PI as PI64; use std::iter::once; +use std::mem::size_of; use std::ptr::copy_nonoverlapping; use std::sync::Arc; use v_frame::frame::Frame; use v_frame::pixel::Pixel; -use v_frame::plane::Plane; use wide::f32x8; pub const TEMPORAL_RADIUS: usize = 1; @@ -21,6 +23,11 @@ const BLOCK_STEP: usize = 12; const REAL_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * BLOCK_SIZE; const COMPLEX_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1); +// The C implementation uses this f32x16 type, which is implemented the same way +// for non-avx512 systems. `wide` doesn't have a f32x16 type, so we mimic it. +#[allow(non_camel_case_types)] +type f32x16 = [f32x8; 2]; + /// This denoiser is based on the DFTTest2 plugin from Vapoursynth. /// This type of denoising was chosen because it provides /// high quality while not being too slow. The DFTTest2 implementation @@ -32,12 +39,15 @@ where // External values prev_frame: Option>>, pub(crate) cur_frameno: u64, + bit_depth: usize, // Local values - sigma: Aligned<[f32; COMPLEX_SIZE]>, + sigma: f32, window: Aligned<[f32; REAL_SIZE]>, - window_freq: Aligned<[Complex64; COMPLEX_SIZE]>, + window_freq: Aligned<[Complex32; COMPLEX_SIZE]>, pmin: f32, pmax: f32, + padded: Vec, + padded2: Vec, } impl DftDenoiser @@ -45,27 +55,42 @@ where T: Pixel, { // This should only need to run once per video. - pub fn new(sigma: f32) -> Self { + pub fn new( + sigma: f32, width: usize, height: usize, bit_depth: usize, + ) -> Self { + if size_of::() == 1 { + assert!(bit_depth <= 8); + } else { + assert!(bit_depth > 8); + } + let window = build_window(); let wscale = window.iter().copied().map(|w| w * w).sum::(); let sigma = sigma as f32 * wscale; let pmin = 0.0f32; let pmax = 500.0f32 * wscale; - let mut window_freq_real = Aligned::new([0f64; REAL_SIZE]); + let mut window_freq_real = Aligned::new([0f32; REAL_SIZE]); window_freq_real.iter_mut().zip(window.iter()).for_each(|(freq, w)| { - *freq = *w as f64 * 255.0; + *freq = *w as f32 * 255.0; }); let window_freq = rdft(&window_freq_real); - let sigma = Aligned::new([sigma; COMPLEX_SIZE]); + let w_pad_size = calc_pad_size(width); + let h_pad_size = calc_pad_size(height); + let pad_size = w_pad_size * h_pad_size; + let padded = vec![T::zero(); pad_size * TEMPORAL_SIZE]; + let padded2 = vec![0f32; pad_size]; Self { prev_frame: None, cur_frameno: 0, + bit_depth, sigma, window, window_freq, pmin, pmax, + padded, + padded2, } } @@ -79,58 +104,80 @@ where return Err(EncoderStatus::NeedMoreData); } + let next_frame = next_frame.cloned().flatten(); let orig_frame = frame_q.get(&self.cur_frameno).unwrap().as_ref().unwrap(); - let frames = once(self.prev_frame.clone()) - .chain(once(Some(Arc::clone(orig_frame)))) - .chain(next_frame.cloned()) - .collect::, 3>>(); + let frames = + once(self.prev_frame.clone().unwrap_or_else(|| Arc::clone(orig_frame))) + .chain(once(Arc::clone(orig_frame))) + .chain(once(next_frame.unwrap_or_else(|| Arc::clone(orig_frame)))) + .collect::>(); - todo!(); - // let mut dest = (**orig_frame).clone(); - // let mut pad = ArrayVec::<_, TB_SIZE>::new(); - // for i in 0..TB_SIZE { - // let dec = self.chroma_sampling.get_decimation().unwrap_or((0, 0)); - // let mut pad_frame = [ - // Plane::new( - // self.pad_dimensions[0].0, - // self.pad_dimensions[0].1, - // 0, - // 0, - // 0, - // 0, - // ), - // Plane::new( - // self.pad_dimensions[1].0, - // self.pad_dimensions[1].1, - // dec.0, - // dec.1, - // 0, - // 0, - // ), - // Plane::new( - // self.pad_dimensions[2].0, - // 
self.pad_dimensions[2].1, - // dec.0, - // dec.1, - // 0, - // 0, - // ), - // ]; + let mut dest = (**orig_frame).clone(); + for p in 0..3 { + let width = frames[0].planes[p].cfg.width; + let height = frames[0].planes[p].cfg.height; + let stride = frames[0].planes[p].cfg.stride; + let w_pad_size = calc_pad_size(width); + let h_pad_size = calc_pad_size(height); + let pad_size_spatial = w_pad_size * h_pad_size; + for i in 0..TEMPORAL_SIZE { + let src = &frames[i].planes[p]; + reflection_padding( + &mut self.padded[(i * pad_size_spatial)..], + src.data_origin(), + width, + height, + stride, + ) + } + + for i in 0..h_pad_size { + for j in 0..w_pad_size { + let mut block = [f32x16::default(); 7 * BLOCK_SIZE * 2]; + let offset_x = w_pad_size; + load_block( + &mut block, + &self.padded[((i * offset_x + j) * BLOCK_STEP)..], + width, + height, + self.bit_depth, + &self.window.data, + ); + fused( + &mut block, + self.sigma, + self.pmin, + self.pmax, + &self.window_freq.data, + ); + store_block( + &mut self.padded2[((i * offset_x + j) * BLOCK_STEP)..], + &block[(TEMPORAL_RADIUS * BLOCK_SIZE * 2)..], + width, + height, + &self.window[(TEMPORAL_RADIUS * BLOCK_SIZE * 2 * 16)..], + ); + todo!() + } + } - // let frame = frames.get(&i).unwrap_or(&frames[&TEMP_RADIUS]); - // self.copy_pad(frame, &mut pad_frame); - // pad.push(pad_frame); - // } - // self.do_filtering(&pad, &mut dest); + let offset_y = (h_pad_size - height) / 2; + let offset_x = (w_pad_size - width) / 2; + let dest_plane = &mut dest.planes[p]; + store_frame( + dest_plane.data_origin_mut(), + &self.padded2[(offset_y * w_pad_size + offset_x)..], + width, + height, + stride, + w_pad_size, + ); + } self.prev_frame = Some(Arc::clone(orig_frame)); self.cur_frameno += 1; - // Ok(dest) - } - - fn do_filtering(&mut self, src: &[[Plane; 3]], dest: &mut Frame) { - todo!(); + Ok(dest) } } @@ -203,7 +250,7 @@ fn normalize(window: &[f32; BLOCK_SIZE]) -> [f32; BLOCK_SIZE] { // Identical to Vapoursynth's implementation `vs_bitblt` // which basically copies the pixels in a plane. 
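
// Not part of the patch: a usage sketch for `bitblt`, assuming the usual
// row-by-row semantics described above (copy `width` pixels per row,
// advancing source and destination by their respective strides):
#[cfg(test)]
#[test]
fn bitblt_sketch() {
  // a 2x2 payload stored with stride 3, copied into a stride-4 buffer
  let src = [1u8, 2, 0, 3, 4, 0];
  let mut dst = [0u8; 8];
  bitblt(&mut dst, 4, &src, 3, 2, 2);
  assert_eq!(dst, [1, 2, 0, 0, 3, 4, 0, 0]);
}
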
-fn bitblt( +pub fn bitblt( mut dest: &mut [T], dest_stride: usize, mut src: &[T], src_stride: usize, width: usize, height: usize, ) { @@ -219,11 +266,11 @@ fn bitblt( } fn rdft( - input: &Aligned<[f64; REAL_SIZE]>, -) -> Aligned<[Complex64; COMPLEX_SIZE]> { + input: &Aligned<[f32; REAL_SIZE]>, +) -> Aligned<[Complex32; COMPLEX_SIZE]> { const SHAPE: [usize; 3] = [TEMPORAL_SIZE, BLOCK_SIZE, BLOCK_SIZE]; - let mut output = Aligned::new([Complex64::zero(); COMPLEX_SIZE]); + let mut output = Aligned::new([Complex32::zero(); COMPLEX_SIZE]); for i in 0..(SHAPE[0] * SHAPE[1]) { dft( @@ -234,7 +281,7 @@ fn rdft( ); } - let mut output2 = Aligned::new([Complex64::zero(); COMPLEX_SIZE]); + let mut output2 = Aligned::new([Complex32::zero(); COMPLEX_SIZE]); let stride = SHAPE[2] / 2 + 1; for i in 0..SHAPE[0] { @@ -257,20 +304,20 @@ fn rdft( } enum DftInput<'a> { - Real(&'a [f64]), - Complex(&'a [Complex64]), + Real(&'a [f32]), + Complex(&'a [Complex32]), } #[inline(always)] -fn dft(output: &mut [Complex64], input: DftInput, n: usize, stride: usize) { +fn dft(output: &mut [Complex32], input: DftInput, n: usize, stride: usize) { match input { DftInput::Real(input) => { let out_num = n / 2 + 1; for i in 0..out_num { - let mut sum = Complex64::zero(); + let mut sum = Complex32::zero(); for j in 0..n { - let imag = -2f64 * i as f64 * j as f64 * PI64 / n as f64; - let weight = Complex64::new(imag.cos(), imag.sin()); + let imag = -2f32 * i as f32 * j as f32 * PI / n as f32; + let weight = Complex32::new(imag.cos(), imag.sin()); sum += input[j * stride] * weight; } output[i * stride] = sum; @@ -279,10 +326,10 @@ fn dft(output: &mut [Complex64], input: DftInput, n: usize, stride: usize) { DftInput::Complex(input) => { let out_num = n; for i in 0..out_num { - let mut sum = Complex64::zero(); + let mut sum = Complex32::zero(); for j in 0..n { - let imag = -2f64 * i as f64 * j as f64 * PI64 / n as f64; - let weight = Complex64::new(imag.cos(), imag.sin()); + let imag = -2f32 * i as f32 * j as f32 * PI / n as f32; + let weight = Complex32::new(imag.cos(), imag.sin()); sum += input[j * stride] * weight; } output[i * stride] = sum; @@ -290,3 +337,11 @@ fn dft(output: &mut [Complex64], input: DftInput, n: usize, stride: usize) { } } } + +#[inline(always)] +fn calc_pad_size(size: usize) -> usize { + size + + if size % BLOCK_SIZE > 0 { BLOCK_SIZE - size % BLOCK_SIZE } else { 0 } + + BLOCK_SIZE + - BLOCK_STEP.max(BLOCK_STEP) * 2 +} From b067c41a32865b775c61936a5b8756e49fa76019 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Sat, 29 Oct 2022 23:36:00 -0400 Subject: [PATCH 07/11] This is kind of a lot of work --- src/denoise/kernel.rs | 118 ++++++++++++++++++++++++++++++++++++------ src/denoise/mod.rs | 16 ++++-- 2 files changed, 115 insertions(+), 19 deletions(-) diff --git a/src/denoise/kernel.rs b/src/denoise/kernel.rs index a283939b4a..623e977da8 100644 --- a/src/denoise/kernel.rs +++ b/src/denoise/kernel.rs @@ -6,10 +6,10 @@ use std::{ use crate::util::cast; use super::{ - bitblt, calc_pad_size, f32x16, BLOCK_SIZE, BLOCK_STEP, TEMPORAL_SIZE, + bitblt, calc_pad_size, f32x16, BLOCK_SIZE, TEMPORAL_RADIUS, TEMPORAL_SIZE, }; use av_metrics::video::Pixel; -use num_complex::Complex32; +use num_traits::clamp; use wide::{f32x8, u16x8, u8x16}; pub fn reflection_padding( @@ -72,7 +72,7 @@ pub fn reflection_padding( pub fn load_block( block: &mut [f32x16], shifted_src: &[T], width: usize, height: usize, - bit_depth: usize, window: &[f32], + bit_depth: usize, window: &[f32x16], ) { let scale = 1.0f32 / (1 << (bit_depth - 8)) as 
f32; let offset_x = calc_pad_size(width); @@ -106,33 +106,121 @@ pub fn load_block( let f32_lower = u16x8::new(*cast(&u16s[8..])).to_u32x8().to_f32x8(); [f32_upper, f32_lower] }; - let window = &window[(i * BLOCK_SIZE + j)..]; - let window_upper = f32x8::new(*cast(&window[..8])); - let window_lower = f32x8::new(*cast(&window[8..16])); - let result_upper = scale * window_upper * vec_input[0]; - let result_lower = scale * window_lower * vec_input[1]; + let window = &window[i * BLOCK_SIZE + j]; + let result_upper = scale * window[0] * vec_input[0]; + let result_lower = scale * window[1] * vec_input[1]; block[i * BLOCK_SIZE * 2 + j] = [result_upper, result_lower]; } } } pub fn fused( - output: &mut [f32x16], sigma: f32, pmin: f32, pmax: f32, - window_freq: &[Complex32], + block: &mut [f32x16], sigma: f32, pmin: f32, pmax: f32, + window_freq: &[f32x16], ) { - todo!() + for i in 0..TEMPORAL_SIZE { + transpose_16x16(&mut block[(i * 32)..]); + rdft::<16>(&mut block[(i * 32)..]); + transpose_32x16(&mut block[(i * 32)..]); + dft::<16>(&mut block[(i * 32)..]); + } + for i in 0..16 { + dft::<3>(&mut block[(i * 2)..], 16); + } + + let gf = block[0].extract(0) / window_freq[0].extract(0); + remove_mean(block, gf, window_freq); + + frequency_filtering(block, sigma, pmin, pmax); + + add_mean(block, gf, window_freq); + + for i in 0..16 { + idft::<3>(&mut block[(i * 2)..], 16); + } + idft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + transpose_32x16(&mut block[(TEMPORAL_RADIUS * 32)..]); + irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + post_irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + transpose_16x16(&mut block[(TEMPORAL_RADIUS * 32)..]); } pub fn store_block( - output: &mut [f32], input: &[f32x16], width: usize, height: usize, - window: &[f32], + shifted_dst: &mut [f32], shifted_block: &[f32x16], width: usize, + height: usize, shifted_window: &[f32x16], ) { - todo!() + let pad_size = calc_pad_size(width); + for i in 0..BLOCK_SIZE { + let acc = &mut shifted_dst[(i * pad_size)..]; + let mut acc_simd = + [f32x8::new(*cast(&acc[..8])), f32x8::new(*cast(&acc[8..16]))]; + acc_simd[0] = + shifted_block[i][0].mul_add(shifted_window[i][0], acc_simd[0]); + acc_simd[1] = + shifted_block[i][1].mul_add(shifted_window[i][1], acc_simd[1]); + acc[..8].copy_from_slice(acc_simd[0].as_array_ref()); + acc[8..16].copy_from_slice(acc_simd[1].as_array_ref()); + } } pub fn store_frame( output: &mut [T], shifted_src: &[f32], width: usize, height: usize, - dst_stride: usize, src_stride: usize, + bit_depth: usize, dst_stride: usize, src_stride: usize, ) { + let scale = 1.0f32 / (1 << (bit_depth - 8)) as f32; + let peak = (1u32 << bit_depth) - 1; + for y in 0..height { + for x in 0..width { + // SAFETY: We know the bounds of the planes for src and dest + unsafe { + let clamped = clamp( + (*shifted_src.get_unchecked(y * src_stride + x) / scale + 0.5f32) + as u32, + 0u32, + peak, + ); + *output.get_unchecked_mut(y * dst_stride + x) = T::cast_from(clamped); + } + } + } +} + +fn transpose_16x16() { + todo!() +} + +fn transpose_32x16() { + todo!() +} + +fn remove_mean() { + todo!() +} + +fn frequency_filtering() { + todo!() +} + +fn add_mean() { + todo!() +} + +fn rdft() { + todo!() +} + +fn dft() { + todo!() +} + +fn idft() { + todo!() +} + +fn irdft() { + todo!() +} + +fn post_irdft() { todo!() } diff --git a/src/denoise/mod.rs b/src/denoise/mod.rs index 6815ffa78d..f5978bdcf3 100644 --- a/src/denoise/mod.rs +++ b/src/denoise/mod.rs @@ -9,7 +9,7 @@ use num_complex::Complex32; use num_traits::Zero; use 
std::f32::consts::PI; use std::iter::once; -use std::mem::size_of; +use std::mem::{size_of, transmute}; use std::ptr::copy_nonoverlapping; use std::sync::Arc; use v_frame::frame::Frame; @@ -141,21 +141,28 @@ where width, height, self.bit_depth, - &self.window.data, + // SAFETY: We know that the window size is a multiple of 16 + unsafe { transmute(&self.window[..]) }, ); fused( &mut block, self.sigma, self.pmin, self.pmax, - &self.window_freq.data, + // SAFETY: We know that the window size is a multiple of 16 + unsafe { transmute(&self.window_freq[..]) }, ); store_block( &mut self.padded2[((i * offset_x + j) * BLOCK_STEP)..], &block[(TEMPORAL_RADIUS * BLOCK_SIZE * 2)..], width, height, - &self.window[(TEMPORAL_RADIUS * BLOCK_SIZE * 2 * 16)..], + // SAFETY: We know that the window size is a multiple of 16 + unsafe { + transmute( + &self.window[(TEMPORAL_RADIUS * BLOCK_SIZE * 2 * 16)..], + ) + }, ); todo!() } @@ -169,6 +176,7 @@ where &self.padded2[(offset_y * w_pad_size + offset_x)..], width, height, + self.bit_depth, stride, w_pad_size, ); From 413fd00f54430fd55272b4f8e5039d8c30189d6b Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Sun, 30 Oct 2022 04:20:20 -0400 Subject: [PATCH 08/11] Code is pain --- src/denoise/blend.rs | 528 ++++++++++++++++++++++++++++++++++++++++++ src/denoise/kernel.rs | 150 +++++++++++- src/denoise/mod.rs | 5 +- 3 files changed, 672 insertions(+), 11 deletions(-) create mode 100644 src/denoise/blend.rs diff --git a/src/denoise/blend.rs b/src/denoise/blend.rs new file mode 100644 index 0000000000..50e27bc35f --- /dev/null +++ b/src/denoise/blend.rs @@ -0,0 +1,528 @@ +//! A pox on the house of whoever decided this was a good way to code anything + +use std::mem::size_of; + +use arrayvec::ArrayVec; +use wide::{f32x8, f64x4}; + +use super::{f32x16, f64x8}; + +const V_DC: i32 = -256; + +pub fn blend16< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, + const I8: usize, + const I9: usize, + const I10: usize, + const I11: usize, + const I12: usize, + const I13: usize, + const I14: usize, + const I15: usize, +>( + a: f32x16, b: f32x16, +) -> f32x16 { + let x0 = blend_half16::(a, b); + let x1 = blend_half16::(a, b); + [x0, x1] +} + +fn blend_half16< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, +>( + a: f32x16, b: f32x16, +) -> f32x8 { + const N: usize = 8; + let ind: [usize; N] = [I0, I1, I2, I3, I4, I5, I6, I7]; + + // lambda to find which of the four possible sources are used + fn list_sources(ind: &[usize; N]) -> ArrayVec { + let mut source_used = [false; 4]; + for i in 0..N { + let ix = ind[i]; + let src = ix / N; + source_used[src & 3] = true; + } + // return a list of sources used. 
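+  // Source numbering, for reference: index ix selects element ix % N of
+  // source ix / N, where source 0 is the low half of `a`, 1 the high half
+  // of `a`, 2 the low half of `b`, and 3 the high half of `b` (see
+  // `select_blend16` below). E.g. blend16::<0, 16, 1, 17, ...> touches
+  // only sources 0 and 2, so a single blend8_f32 call suffices.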
+ let mut sources = ArrayVec::new(); + for i in 0..4 { + if source_used[i] { + sources.push(i); + } + } + sources + } + + let sources = list_sources(&ind); + if sources.is_empty() { + return f32x8::ZERO; + } + + // get indexes for the first one or two sources + let uindex = if sources.len() > 2 { 1 } else { 2 }; + let l = blend_half_indexes::<8>( + uindex, + sources.get(0).copied(), + sources.get(1).copied(), + &ind, + ); + let src0 = select_blend16(sources.get(0).copied(), a, b); + let src1 = select_blend16(sources.get(1).copied(), a, b); + let mut x0 = + blend8_f32(l[0], l[1], l[2], l[3], l[4], l[5], l[6], l[7], src0, src1); + + // get last one or two sources + if sources.len() > 2 { + let m = blend_half_indexes::<8>( + 1, + sources.get(2).copied(), + sources.get(3).copied(), + &ind, + ); + let src2 = select_blend16(sources.get(2).copied(), a, b); + let src3 = select_blend16(sources.get(3).copied(), a, b); + let x1 = + blend8_f32(m[0], m[1], m[2], m[3], m[4], m[5], m[6], m[7], src2, src3); + + // combine result of two blends. Unused elements are zero + x0 |= x1; + } + + x0 +} + +fn select_blend16(action: Option, a: f32x16, b: f32x16) -> f32x8 { + match action { + Some(0) => a[0], + Some(1) => a[1], + Some(2) => b[0], + _ => b[1], + } +} + +#[inline(always)] +fn blend8_f32( + i0: i32, i1: i32, i2: i32, i3: i32, i4: i32, i5: i32, i6: i32, i7: i32, + a: f32x8, b: f32x8, +) -> f32x8 { + let indexes = [i0, i1, i2, i3, i4, i5, i6, i7]; + let y = a; + let flags = blend_flags::<8, { size_of::() }>(&indexes); + todo!() +} + +pub fn blend8< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, +>( + a: f64x8, b: f64x8, +) -> f64x8 { + let x0 = blend_half8::(a, b); + let x1 = blend_half8::(a, b); + [x0, x1] +} + +fn blend_half8< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, +>( + a: f64x8, b: f64x8, +) -> f64x4 { + const N: usize = 4; + let ind: [usize; N] = [I0, I1, I2, I3]; + + // lambda to find which of the four possible sources are used + fn list_sources(ind: &[usize; N]) -> ArrayVec { + let mut source_used = [false; 4]; + for i in 0..N { + let ix = ind[i]; + let src = ix / N; + source_used[src & 3] = true; + } + // return a list of sources used. + let mut sources = ArrayVec::new(); + for i in 0..4 { + if source_used[i] { + sources.push(i); + } + } + sources + } + + let sources = list_sources(&ind); + if sources.is_empty() { + return f64x4::ZERO; + } + + // get indexes for the first one or two sources + let uindex = if sources.len() > 2 { 1 } else { 2 }; + let l = blend_half_indexes::<4>( + uindex, + sources.get(0).copied(), + sources.get(1).copied(), + &ind, + ); + let src0 = select_blend8(sources.get(0).copied(), a, b); + let src1 = select_blend8(sources.get(1).copied(), a, b); + let mut x0 = blend4_f64(l[0], l[1], l[2], l[3], src0, src1); + + // get last one or two sources + if sources.len() > 2 { + let m = blend_half_indexes::<4>( + 1, + sources.get(2).copied(), + sources.get(3).copied(), + &ind, + ); + let src2 = select_blend8(sources.get(2).copied(), a, b); + let src3 = select_blend8(sources.get(3).copied(), a, b); + let x1 = blend4_f64(m[0], m[1], m[2], m[3], src2, src3); + + // combine result of two blends. Unused elements are zero + x0 |= x1; + } + + x0 +} + +// blend_half_indexes: return an Indexlist for emulating a blend function as +// blends or permutations from multiple sources +// dozero = 0: let unused elements be don't care. 
Multiple permutation results must be blended +// dozero = 1: zero unused elements in each permuation. Multiple permutation results can be OR'ed +// dozero = 2: indexes that are -1 or V_DC are preserved +// src1, src2: sources to blend in a partial implementation +fn blend_half_indexes( + dozero: u8, src1: Option, src2: Option, ind: &[usize; N], +) -> ArrayVec { + // a is a reference to a constexpr array of permutation indexes + let mut list = ArrayVec::new(); + // value to use for unused entries + let u = if dozero > 0 { -1 } else { V_DC }; + + for j in 0..N { + let idx = ind[j]; + let src = idx / N; + list.push(if src1 == Some(src) { + (idx & (N - 1)) as i32 + } else if src2 == Some(src) { + ((idx & (N - 1)) + N) as i32 + } else { + u + }); + } + + list +} + +fn select_blend8(action: Option, a: f64x8, b: f64x8) -> f64x4 { + match action { + Some(0) => a[0], + Some(1) => a[1], + Some(2) => b[0], + _ => b[1], + } +} + +#[inline(always)] +fn blend4_f64( + i0: i32, i1: i32, i2: i32, i3: i32, a: f64x4, b: f64x4, +) -> f64x4 { + todo!() +} + +// blend_flags: returns information about how a blend function can be implemented +// The return value is composed of these flag bits: + +// needs zeroing +const BLEND_ZEROING: u64 = 1; +// all is zero or don't care +const BLEND_ALLZERO: u64 = 2; +// fits blend with a larger block size (e.g permute Vec2q instead of Vec4i) +const BLEND_LARGEBLOCK: u64 = 4; +// additional zeroing needed after blend with larger block size or shift +const BLEND_ADDZ: u64 = 8; +// has data from a +const BLEND_A: u64 = 0x10; +// has data from b +const BLEND_B: u64 = 0x20; +// permutation of a needed +const BLEND_PERMA: u64 = 0x40; +// permutation of b needed +const BLEND_PERMB: u64 = 0x80; +// permutation crossing 128-bit lanes +const BLEND_CROSS_LANE: u64 = 0x100; +// same permute/blend pattern in all 128-bit lanes +const BLEND_SAME_PATTERN: u64 = 0x200; +// pattern fits punpckh(a,b) +const BLEND_PUNPCKHAB: u64 = 0x1000; +// pattern fits punpckh(b,a) +const BLEND_PUNPCKHBA: u64 = 0x2000; +// pattern fits punpckl(a,b) +const BLEND_PUNPCKLAB: u64 = 0x4000; +// pattern fits punpckl(b,a) +const BLEND_PUNPCKLBA: u64 = 0x8000; +// pattern fits palignr(a,b) +const BLEND_ROTATEAB: u64 = 0x10000; +// pattern fits palignr(b,a) +const BLEND_ROTATEBA: u64 = 0x20000; +// pattern fits shufps/shufpd(a,b) +const BLEND_SHUFAB: u64 = 0x40000; +// pattern fits shufps/shufpd(b,a) +const BLEND_SHUFBA: u64 = 0x80000; +// pattern fits rotation across lanes. 
count returned in bits blend_rotpattern +const BLEND_ROTATE_BIG: u64 = 0x100000; +// index out of range +const BLEND_OUTOFRANGE: u64 = 0x10000000; +// pattern for shufps/shufpd is in bit blend_shufpattern to blend_shufpattern + 7 +const BLEND_SHUFPATTERN: u64 = 32; +// pattern for palignr is in bit blend_rotpattern to blend_rotpattern + 7 +const BLEND_ROTPATTERN: u64 = 40; + +// INSTRSET = 8 +fn blend_flags(a: &[i32; N]) -> u64 { + let mut r = BLEND_LARGEBLOCK | BLEND_SAME_PATTERN | BLEND_ALLZERO; + // number of 128-bit lanes + let n_lanes = V / 16; + // elements per lane + let lane_size = N / n_lanes; + // current lane + let mut lane = 0; + // rotate left count + let mut rot: u32 = 999; + // GeNERIc pARAMeTerS canNoT BE useD In consT ConTEXTs + let mut lane_pattern = vec![0; lane_size]; + if lane_size == 2 && N <= 8 { + r |= BLEND_SHUFAB | BLEND_SHUFBA; + } + + for ii in 0..N { + let ix = a[ii]; + if ix < 0 { + if ix == -1 { + r |= BLEND_ZEROING; + } else if ix != V_DC { + r = BLEND_OUTOFRANGE; + break; + } + } else { + r &= !BLEND_ALLZERO; + if ix < N as i32 { + r |= BLEND_A; + if ix != ii as i32 { + r |= BLEND_PERMA; + } + } else if ix < 2 * N as i32 { + r |= BLEND_B; + if ix != (ii + N) as i32 { + r |= BLEND_PERMB; + } + } else { + r = BLEND_OUTOFRANGE; + break; + } + } + + // check if pattern fits a larger block size: + // even indexes must be even, odd indexes must fit the preceding even index + 1 + if (ii & 1) == 0 { + if ix >= 0 && (ix & 1) > 0 { + r &= !BLEND_LARGEBLOCK; + } + let iy = a[ii + 1]; + if iy >= 0 && (iy & 1) == 0 { + r &= !BLEND_LARGEBLOCK; + } + if ix >= 0 && iy >= 0 && iy != ix + 1 { + r &= !BLEND_LARGEBLOCK; + } + if ix == -1 && iy >= 0 { + r |= BLEND_ADDZ; + } + if iy == -1 && ix >= 0 { + r |= BLEND_ADDZ; + } + } + + lane = ii / lane_size; + if lane == 0 { + lane_pattern[ii] = ix; + } + + // check if crossing lanes + if ix >= 0 { + let lane_i = (ix & !(N as i32)) as usize / lane_size; + if lane_i != lane { + r |= BLEND_CROSS_LANE; + } + if lane_size == 2 { + // check if it fits pshufd + if lane_i != lane { + r &= !(BLEND_SHUFAB | BLEND_SHUFBA); + } + if (((ix & (N as i32)) != 0) as usize ^ ii) & 1 > 0 { + r &= !BLEND_SHUFAB; + } else { + r &= !BLEND_SHUFBA; + } + } + } + + // check if same pattern in all lanes + if lane != 0 && ix >= 0 { + let j = ii - (lane * lane_size); + let jx = ix - (lane * lane_size) as i32; + if jx < 0 || (jx & !(N as i32)) >= lane_size as i32 { + r &= !BLEND_SAME_PATTERN; + } + if lane_pattern[j] < 0 { + lane_pattern[j] = jx; + } else if lane_pattern[j] != jx { + r &= !BLEND_SAME_PATTERN; + } + } + } + + if r & BLEND_LARGEBLOCK == 0 { + r &= !BLEND_ADDZ; + } + if r & BLEND_CROSS_LANE > 0 { + r &= !BLEND_SAME_PATTERN; + } + if r & (BLEND_PERMA | BLEND_PERMB) == 0 { + return r; + } + + if r & BLEND_SAME_PATTERN > 0 { + // same pattern in all lanes. 
check if it fits unpack patterns + r |= BLEND_PUNPCKHAB | BLEND_PUNPCKHBA | BLEND_PUNPCKLAB | BLEND_PUNPCKLBA; + for iu in 0..(lane_size as u32) { + let ix = lane_pattern[iu as usize]; + if ix >= 0 { + let ix = ix as u32; + if ix != iu / 2 + (iu & 1) * N as u32 { + r &= !BLEND_PUNPCKLAB; + } + if ix != iu / 2 + ((iu & 1) ^ 1) * N as u32 { + r &= !BLEND_PUNPCKLBA; + } + if ix != (iu + lane_size as u32) / 2 + (iu & 1) * N as u32 { + r &= !BLEND_PUNPCKHAB; + } + if ix != (iu + lane_size as u32) / 2 + ((iu & 1) ^ 1) * N as u32 { + r &= !BLEND_PUNPCKHBA; + } + } + } + + for iu in 0..(lane_size as u32) { + // check if it fits palignr + let ix = lane_pattern[iu as usize]; + if ix >= 0 { + let ix = ix as u32; + let t = ix & !(N as u32); + if (ix & N as u32) > 0 { + t += lane_size as u32; + } + let tb = (t + 2 * lane_size as u32 - iu) % (lane_size as u32 * 2); + if rot == 999 { + rot = tb; + } else if rot != tb { + rot = 1000; + } + } + } + if rot < 999 { + // fits palignr + if rot < lane_size as u32 { + r |= BLEND_ROTATEBA; + } else { + r |= BLEND_ROTATEAB; + } + let elem_size = (V / N) as u32; + r |= (((rot & (lane_size as u32 - 1)) * elem_size) as u64) + << BLEND_ROTPATTERN; + } + + if lane_size == 4 { + // check if it fits shufps + r |= BLEND_SHUFAB | BLEND_SHUFBA; + for ii in 0..2 { + let ix = lane_pattern[ii]; + if ix >= 0 { + if ix & N as i32 > 0 { + r &= !BLEND_SHUFAB; + } else { + r &= !BLEND_SHUFBA; + } + } + } + for ii in 2..4 { + let ix = lane_pattern[ii]; + if ix >= 0 { + if ix & N as i32 > 0 { + r &= !BLEND_SHUFBA; + } else { + r &= !BLEND_SHUFAB; + } + } + } + if r & (BLEND_SHUFAB | BLEND_SHUFBA) > 0 { + // fits shufps/shufpd + let shuf_pattern = 0u8; + for iu in 0..lane_size { + shuf_pattern |= ((lane_pattern[iu] & 3) as u8) << (iu * 2); + } + r |= (shuf_pattern as u64) << BLEND_SHUFPATTERN; + } + } + } else if n_lanes > 1 { + // not same pattern in all lanes + let mut rot = 999; + for ii in 0..N { + let ix = a[ii]; + if ix >= 0 { + let rot2: u32 = + (ix + 2 * N as i32 - ii as i32) as u32 % (2 * N) as u32; + if rot == 999 { + rot = rot2; + } else if rot != rot2 { + rot = 1000; + break; + } + } + } + if rot < 2 * N as u32 { + // fits big rotate + r |= BLEND_ROTATE_BIG | (rot as u64) << BLEND_ROTPATTERN; + } + } + if lane_size == 2 && (r & (BLEND_SHUFAB | BLEND_SHUFBA)) > 0 { + for ii in 0..N { + r |= ((a[ii] & 1) as u64) << (BLEND_SHUFPATTERN + ii as u64); + } + } + + r +} diff --git a/src/denoise/kernel.rs b/src/denoise/kernel.rs index 623e977da8..443439bea1 100644 --- a/src/denoise/kernel.rs +++ b/src/denoise/kernel.rs @@ -3,11 +3,11 @@ use std::{ ptr::copy_nonoverlapping, }; -use crate::util::cast; - +use super::blend::{blend16, blend8}; use super::{ bitblt, calc_pad_size, f32x16, BLOCK_SIZE, TEMPORAL_RADIUS, TEMPORAL_SIZE, }; +use crate::util::cast; use av_metrics::video::Pixel; use num_traits::clamp; use wide::{f32x8, u16x8, u8x16}; @@ -119,7 +119,7 @@ pub fn fused( window_freq: &[f32x16], ) { for i in 0..TEMPORAL_SIZE { - transpose_16x16(&mut block[(i * 32)..]); + transpose_16x16::<1>(&mut block[(i * 32)..]); rdft::<16>(&mut block[(i * 32)..]); transpose_32x16(&mut block[(i * 32)..]); dft::<16>(&mut block[(i * 32)..]); @@ -142,7 +142,7 @@ pub fn fused( transpose_32x16(&mut block[(TEMPORAL_RADIUS * 32)..]); irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); post_irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); - transpose_16x16(&mut block[(TEMPORAL_RADIUS * 32)..]); + transpose_16x16::<1>(&mut block[(TEMPORAL_RADIUS * 32)..]); } pub fn store_block( @@ -185,12 +185,142 @@ pub 
fn store_frame( } } -fn transpose_16x16() { - todo!() +fn transpose_16x16(block: &mut [f32x16]) { + for i in 0..2 { + for j in 0..2 { + for k in 0..2 { + let id1 = ((i * 2 + j) * 2 + k) * 2 * STRIDE; + let id2 = (((i * 2 + j) * 2 + k) * 2 + 1) * STRIDE; + let temp1 = blend16::< + 0, + 2, + 16, + 18, + 4, + 6, + 20, + 22, + 8, + 10, + 24, + 26, + 12, + 14, + 28, + 30, + >(block[id1], block[id2]); + let temp2 = blend16::< + 1, + 3, + 17, + 19, + 5, + 7, + 21, + 23, + 9, + 11, + 25, + 27, + 13, + 15, + 29, + 31, + >(block[id1], block[id2]); + block[id1] = temp1; + block[id2] = temp2; + } + + for k in 0..2 { + let id1 = (((i * 2 + j) * 2) * 2 + k) * STRIDE; + let id2 = (((i * 2 + j) * 2 + 1) * 2 + k) * STRIDE; + let temp1 = blend16::< + 0, + 2, + 16, + 18, + 4, + 6, + 20, + 22, + 8, + 10, + 24, + 26, + 12, + 14, + 28, + 30, + >(block[id1], block[id2]); + let temp2 = blend16::< + 1, + 3, + 17, + 19, + 5, + 7, + 21, + 23, + 9, + 11, + 25, + 27, + 13, + 15, + 29, + 31, + >(block[id1], block[id2]); + block[id1] = temp1; + block[id2] = temp2; + } + } + + for j in 0..4 { + let id1 = (i * 8 + j) * STRIDE; + let id2 = (i * 8 + 4 + j) * STRIDE; + // SAFETY: Types are the same size + let temp1 = unsafe { + transmute(blend8::<0, 1, 8, 9, 4, 5, 12, 13>( + transmute(block[id1]), + transmute(block[id2]), + )) + }; + // SAFETY: Types are the same size + let temp2 = unsafe { + transmute(blend8::<2, 3, 10, 11, 6, 7, 14, 15>( + transmute(block[id1]), + transmute(block[id2]), + )) + }; + block[id1] = temp1; + block[id2] = temp2; + } + } + + for i in 0..8 { + // SAFETY: Types are the same size + let temp1 = unsafe { + transmute(blend8::<0, 1, 2, 3, 8, 9, 10, 11>( + transmute(block[i * STRIDE]), + transmute(block[(i + 8) * STRIDE]), + )) + }; + // SAFETY: Types are the same size + let temp2 = unsafe { + transmute(blend8::<4, 5, 6, 7, 12, 13, 14, 15>( + transmute(block[i * STRIDE]), + transmute(block[(i + 8) * STRIDE]), + )) + }; + block[i * STRIDE] = temp1; + block[(i + 8) * STRIDE] = temp2; + } } -fn transpose_32x16() { - todo!() +fn transpose_32x16(block: &mut [f32x16]) { + const STRIDE: usize = 1; + transpose_16x16::<2>(block); + transpose_16x16::<2>(&mut block[STRIDE..]); } fn remove_mean() { @@ -205,11 +335,11 @@ fn add_mean() { todo!() } -fn rdft() { +fn rdft(block: &mut [f32x16]) { todo!() } -fn dft() { +fn dft(block: &mut [f32x16]) { todo!() } diff --git a/src/denoise/mod.rs b/src/denoise/mod.rs index f5978bdcf3..429c55f6ec 100644 --- a/src/denoise/mod.rs +++ b/src/denoise/mod.rs @@ -1,3 +1,4 @@ +mod blend; mod kernel; use crate::api::FrameQueue; @@ -14,7 +15,7 @@ use std::ptr::copy_nonoverlapping; use std::sync::Arc; use v_frame::frame::Frame; use v_frame::pixel::Pixel; -use wide::f32x8; +use wide::{f32x8, f64x4}; pub const TEMPORAL_RADIUS: usize = 1; const TEMPORAL_SIZE: usize = TEMPORAL_RADIUS * 2 + 1; @@ -27,6 +28,8 @@ const COMPLEX_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1); // for non-avx512 systems. `wide` doesn't have a f32x16 type, so we mimic it. #[allow(non_camel_case_types)] type f32x16 = [f32x8; 2]; +#[allow(non_camel_case_types)] +type f64x8 = [f64x4; 2]; /// This denoiser is based on the DFTTest2 plugin from Vapoursynth. 
/// This type of denoising was chosen because it provides From bad4ad07f029454d14726ba06bcd0f23b3858c0e Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Tue, 1 Nov 2022 00:10:23 -0400 Subject: [PATCH 09/11] This is a ridiculously stupid amount of code --- src/denoise/blend.rs | 662 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 597 insertions(+), 65 deletions(-) diff --git a/src/denoise/blend.rs b/src/denoise/blend.rs index 50e27bc35f..19411222e6 100644 --- a/src/denoise/blend.rs +++ b/src/denoise/blend.rs @@ -1,10 +1,12 @@ //! A pox on the house of whoever decided this was a good way to code anything -use std::mem::size_of; +use std::mem::{size_of, transmute}; use arrayvec::ArrayVec; use wide::{f32x8, f64x4}; +use crate::util::cast; + use super::{f32x16, f64x8}; const V_DC: i32 = -256; @@ -82,8 +84,7 @@ fn blend_half16< ); let src0 = select_blend16(sources.get(0).copied(), a, b); let src1 = select_blend16(sources.get(1).copied(), a, b); - let mut x0 = - blend8_f32(l[0], l[1], l[2], l[3], l[4], l[5], l[6], l[7], src0, src1); + let mut x0 = blend8_f32(cast(&l[..8]), src0, src1); // get last one or two sources if sources.len() > 2 { @@ -95,8 +96,7 @@ fn blend_half16< ); let src2 = select_blend16(sources.get(2).copied(), a, b); let src3 = select_blend16(sources.get(3).copied(), a, b); - let x1 = - blend8_f32(m[0], m[1], m[2], m[3], m[4], m[5], m[6], m[7], src2, src3); + let x1 = blend8_f32(cast(&m[..8]), src2, src3); // combine result of two blends. Unused elements are zero x0 |= x1; @@ -115,14 +115,74 @@ fn select_blend16(action: Option, a: f32x16, b: f32x16) -> f32x8 { } #[inline(always)] -fn blend8_f32( - i0: i32, i1: i32, i2: i32, i3: i32, i4: i32, i5: i32, i6: i32, i7: i32, - a: f32x8, b: f32x8, -) -> f32x8 { - let indexes = [i0, i1, i2, i3, i4, i5, i6, i7]; - let y = a; +fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 { + let mut y = a; let flags = blend_flags::<8, { size_of::() }>(&indexes); - todo!() + + assert!(flags & BLEND_OUTOFRANGE == 0); + + if flags & BLEND_ALLZERO > 0 { + return f32x8::ZERO; + } + + if flags & BLEND_LARGEBLOCK > 0 { + // blend and permute 32-bit blocks + let l = largeblock_perm::<8, 4>(indexes); + // SAFETY: Types are of same size + let b4: f32x8 = + unsafe { transmute(blend4_f64(&l, transmute(a), transmute(b))) }; + if flags & BLEND_ADDZ == 0 { + // no remaining zeroing + return y; + } + } else if flags & BLEND_B == 0 { + // nothing from b. just permute a + return permute8(indexes, a); + } else if flags & BLEND_A == 0 { + let l = blend_perm_indexes::<8, 2>(indexes); + return permute8(cast(&l[8..]), b); + } else if flags & (BLEND_PERMA | BLEND_PERMB) == 0 { + // no permutation, only blending + let mb = make_bit_mask::<8>(0x303, indexes) as u8; + y = mb.blend(b, a); + } else if flags & BLEND_PUNPCKLAB > 0 { + // y = _mm256_unpacklo_ps(a, b); + todo!(); + } else if flags & BLEND_PUNPCKLBA > 0 { + // y = _mm256_unpacklo_ps(b, a); + todo!(); + } else if flags & BLEND_PUNPCKHAB > 0 { + // y = _mm256_unpackhi_ps(a, b); + todo!(); + } else if flags & BLEND_PUNPCKHBA > 0 { + // y = _mm256_unpackhi_ps(b, a); + todo!(); + } else if flags & BLEND_SHUFAB > 0 { + // use floating point instruction shufpd + // y = _mm256_shuffle_ps(a, b, (flags >> BLEND_SHUFPATTERN) as u8); + todo!(); + } else if flags & BLEND_SHUFBA > 0 { + // use floating point instruction shufpd + // y = _mm256_shuffle_ps(b, a, (flags >> BLEND_SHUFPATTERN) as u8); + todo!(); + } else { + // No special cases + // permute a and b separately, then blend. 
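+    // General path: blend_perm_indexes presumably splits `indexes` into
+    // one permutation list per operand (lanes owned by the other operand
+    // becoming don't-care), so `a` and `b` are permuted into place
+    // separately and the bit mask from `make_bit_mask` then picks each
+    // output lane from one permuted source or the other.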
+ let l = blend_perm_indexes::<8, 0>(indexes); + let ya = permute8(cast(&l[..8]), a); + let yb = permute8(cast(&l[8..]), b); + let mb = make_bit_mask::<8>(0x303, indexes) as u8; + y = mb.blend(yb, ya); + } + if flags & BLEND_ZEROING > 0 { + // additional zeroing needed + // let bm = zero_mask_broad::<8>(indexes); + // let bm1 = _mm256_loadu_si256(bm as __m256i); + // y = _mm256_and_ps(_mm256_castsi256_ps(bm1), y); + todo!(); + } + + y } pub fn blend8< @@ -186,7 +246,7 @@ fn blend_half8< ); let src0 = select_blend8(sources.get(0).copied(), a, b); let src1 = select_blend8(sources.get(1).copied(), a, b); - let mut x0 = blend4_f64(l[0], l[1], l[2], l[3], src0, src1); + let mut x0 = blend4_f64(cast(&l[..4]), src0, src1); // get last one or two sources if sources.len() > 2 { @@ -198,7 +258,7 @@ fn blend_half8< ); let src2 = select_blend8(sources.get(2).copied(), a, b); let src3 = select_blend8(sources.get(3).copied(), a, b); - let x1 = blend4_f64(m[0], m[1], m[2], m[3], src2, src3); + let x1 = blend4_f64(cast(&m[..4]), src2, src3); // combine result of two blends. Unused elements are zero x0 |= x1; @@ -246,9 +306,7 @@ fn select_blend8(action: Option, a: f64x8, b: f64x8) -> f64x4 { } #[inline(always)] -fn blend4_f64( - i0: i32, i1: i32, i2: i32, i3: i32, a: f64x4, b: f64x4, -) -> f64x4 { +fn blend4_f64(indexes: &[i32; 4], a: f64x4, b: f64x4) -> f64x4 { todo!() } @@ -302,7 +360,7 @@ const BLEND_ROTPATTERN: u64 = 40; // INSTRSET = 8 fn blend_flags(a: &[i32; N]) -> u64 { - let mut r = BLEND_LARGEBLOCK | BLEND_SAME_PATTERN | BLEND_ALLZERO; + let mut ret = BLEND_LARGEBLOCK | BLEND_SAME_PATTERN | BLEND_ALLZERO; // number of 128-bit lanes let n_lanes = V / 16; // elements per lane @@ -312,34 +370,34 @@ fn blend_flags(a: &[i32; N]) -> u64 { // rotate left count let mut rot: u32 = 999; // GeNERIc pARAMeTerS canNoT BE useD In consT ConTEXTs - let mut lane_pattern = vec![0; lane_size]; + let mut lane_pattern = vec![0; lane_size].into_boxed_slice(); if lane_size == 2 && N <= 8 { - r |= BLEND_SHUFAB | BLEND_SHUFBA; + ret |= BLEND_SHUFAB | BLEND_SHUFBA; } for ii in 0..N { let ix = a[ii]; if ix < 0 { if ix == -1 { - r |= BLEND_ZEROING; + ret |= BLEND_ZEROING; } else if ix != V_DC { - r = BLEND_OUTOFRANGE; + ret = BLEND_OUTOFRANGE; break; } } else { - r &= !BLEND_ALLZERO; + ret &= !BLEND_ALLZERO; if ix < N as i32 { - r |= BLEND_A; + ret |= BLEND_A; if ix != ii as i32 { - r |= BLEND_PERMA; + ret |= BLEND_PERMA; } } else if ix < 2 * N as i32 { - r |= BLEND_B; + ret |= BLEND_B; if ix != (ii + N) as i32 { - r |= BLEND_PERMB; + ret |= BLEND_PERMB; } } else { - r = BLEND_OUTOFRANGE; + ret = BLEND_OUTOFRANGE; break; } } @@ -348,20 +406,20 @@ fn blend_flags(a: &[i32; N]) -> u64 { // even indexes must be even, odd indexes must fit the preceding even index + 1 if (ii & 1) == 0 { if ix >= 0 && (ix & 1) > 0 { - r &= !BLEND_LARGEBLOCK; + ret &= !BLEND_LARGEBLOCK; } let iy = a[ii + 1]; if iy >= 0 && (iy & 1) == 0 { - r &= !BLEND_LARGEBLOCK; + ret &= !BLEND_LARGEBLOCK; } if ix >= 0 && iy >= 0 && iy != ix + 1 { - r &= !BLEND_LARGEBLOCK; + ret &= !BLEND_LARGEBLOCK; } if ix == -1 && iy >= 0 { - r |= BLEND_ADDZ; + ret |= BLEND_ADDZ; } if iy == -1 && ix >= 0 { - r |= BLEND_ADDZ; + ret |= BLEND_ADDZ; } } @@ -374,17 +432,17 @@ fn blend_flags(a: &[i32; N]) -> u64 { if ix >= 0 { let lane_i = (ix & !(N as i32)) as usize / lane_size; if lane_i != lane { - r |= BLEND_CROSS_LANE; + ret |= BLEND_CROSS_LANE; } if lane_size == 2 { // check if it fits pshufd if lane_i != lane { - r &= !(BLEND_SHUFAB | BLEND_SHUFBA); + ret &= !(BLEND_SHUFAB | 
BLEND_SHUFBA); } if (((ix & (N as i32)) != 0) as usize ^ ii) & 1 > 0 { - r &= !BLEND_SHUFAB; + ret &= !BLEND_SHUFAB; } else { - r &= !BLEND_SHUFBA; + ret &= !BLEND_SHUFBA; } } } @@ -394,44 +452,45 @@ fn blend_flags(a: &[i32; N]) -> u64 { let j = ii - (lane * lane_size); let jx = ix - (lane * lane_size) as i32; if jx < 0 || (jx & !(N as i32)) >= lane_size as i32 { - r &= !BLEND_SAME_PATTERN; + ret &= !BLEND_SAME_PATTERN; } if lane_pattern[j] < 0 { lane_pattern[j] = jx; } else if lane_pattern[j] != jx { - r &= !BLEND_SAME_PATTERN; + ret &= !BLEND_SAME_PATTERN; } } } - if r & BLEND_LARGEBLOCK == 0 { - r &= !BLEND_ADDZ; + if ret & BLEND_LARGEBLOCK == 0 { + ret &= !BLEND_ADDZ; } - if r & BLEND_CROSS_LANE > 0 { - r &= !BLEND_SAME_PATTERN; + if ret & BLEND_CROSS_LANE > 0 { + ret &= !BLEND_SAME_PATTERN; } - if r & (BLEND_PERMA | BLEND_PERMB) == 0 { - return r; + if ret & (BLEND_PERMA | BLEND_PERMB) == 0 { + return ret; } - if r & BLEND_SAME_PATTERN > 0 { + if ret & BLEND_SAME_PATTERN > 0 { // same pattern in all lanes. check if it fits unpack patterns - r |= BLEND_PUNPCKHAB | BLEND_PUNPCKHBA | BLEND_PUNPCKLAB | BLEND_PUNPCKLBA; + ret |= + BLEND_PUNPCKHAB | BLEND_PUNPCKHBA | BLEND_PUNPCKLAB | BLEND_PUNPCKLBA; for iu in 0..(lane_size as u32) { let ix = lane_pattern[iu as usize]; if ix >= 0 { let ix = ix as u32; if ix != iu / 2 + (iu & 1) * N as u32 { - r &= !BLEND_PUNPCKLAB; + ret &= !BLEND_PUNPCKLAB; } if ix != iu / 2 + ((iu & 1) ^ 1) * N as u32 { - r &= !BLEND_PUNPCKLBA; + ret &= !BLEND_PUNPCKLBA; } if ix != (iu + lane_size as u32) / 2 + (iu & 1) * N as u32 { - r &= !BLEND_PUNPCKHAB; + ret &= !BLEND_PUNPCKHAB; } if ix != (iu + lane_size as u32) / 2 + ((iu & 1) ^ 1) * N as u32 { - r &= !BLEND_PUNPCKHBA; + ret &= !BLEND_PUNPCKHBA; } } } @@ -456,25 +515,25 @@ fn blend_flags(a: &[i32; N]) -> u64 { if rot < 999 { // fits palignr if rot < lane_size as u32 { - r |= BLEND_ROTATEBA; + ret |= BLEND_ROTATEBA; } else { - r |= BLEND_ROTATEAB; + ret |= BLEND_ROTATEAB; } let elem_size = (V / N) as u32; - r |= (((rot & (lane_size as u32 - 1)) * elem_size) as u64) + ret |= (((rot & (lane_size as u32 - 1)) * elem_size) as u64) << BLEND_ROTPATTERN; } if lane_size == 4 { // check if it fits shufps - r |= BLEND_SHUFAB | BLEND_SHUFBA; + ret |= BLEND_SHUFAB | BLEND_SHUFBA; for ii in 0..2 { let ix = lane_pattern[ii]; if ix >= 0 { if ix & N as i32 > 0 { - r &= !BLEND_SHUFAB; + ret &= !BLEND_SHUFAB; } else { - r &= !BLEND_SHUFBA; + ret &= !BLEND_SHUFBA; } } } @@ -482,19 +541,19 @@ fn blend_flags(a: &[i32; N]) -> u64 { let ix = lane_pattern[ii]; if ix >= 0 { if ix & N as i32 > 0 { - r &= !BLEND_SHUFBA; + ret &= !BLEND_SHUFBA; } else { - r &= !BLEND_SHUFAB; + ret &= !BLEND_SHUFAB; } } } - if r & (BLEND_SHUFAB | BLEND_SHUFBA) > 0 { + if ret & (BLEND_SHUFAB | BLEND_SHUFBA) > 0 { // fits shufps/shufpd let shuf_pattern = 0u8; for iu in 0..lane_size { shuf_pattern |= ((lane_pattern[iu] & 3) as u8) << (iu * 2); } - r |= (shuf_pattern as u64) << BLEND_SHUFPATTERN; + ret |= (shuf_pattern as u64) << BLEND_SHUFPATTERN; } } } else if n_lanes > 1 { @@ -515,14 +574,487 @@ fn blend_flags(a: &[i32; N]) -> u64 { } if rot < 2 * N as u32 { // fits big rotate - r |= BLEND_ROTATE_BIG | (rot as u64) << BLEND_ROTPATTERN; + ret |= BLEND_ROTATE_BIG | (rot as u64) << BLEND_ROTPATTERN; } } - if lane_size == 2 && (r & (BLEND_SHUFAB | BLEND_SHUFBA)) > 0 { + if lane_size == 2 && (ret & (BLEND_SHUFAB | BLEND_SHUFBA)) > 0 { for ii in 0..N { - r |= ((a[ii] & 1) as u64) << (BLEND_SHUFPATTERN + ii as u64); + ret |= ((a[ii] & 1) as u64) << (BLEND_SHUFPATTERN 
+ ii as u64); + } + } + + ret +} + +// largeblock_perm: return indexes for replacing a permute or blend with +// a certain block size by a permute or blend with the double block size. +// Note: it is presupposed that perm_flags() indicates perm_largeblock +// It is required that additional zeroing is added if perm_flags() indicates perm_addz +fn largeblock_perm( + a: &[i32; N], +) -> [i32; N2] { + // GeNERIc pARAMeTerS canNoT BE useD In consT ConTEXTs + assert!(N2 == N / 2); + + // Parameter a is a reference to a constexpr array of permutation indexes + let mut list = [0; N2]; + let mut fit_addz = false; + + // check if additional zeroing is needed at current block size + for i in (0..N).step_by(2) { + let ix = a[i]; + let iy = a[i + 1]; + if (ix == -1 && iy >= 0) || (iy == -1 && ix >= 0) { + fit_addz = true; + break; + } + } + + // loop through indexes + for i in (0..N).step_by(2) { + let ix = a[i]; + let iy = a[i + 1]; + let iz = if ix >= 0 { + ix / 2 + } else if iy >= 0 { + iy / 2 + } else if fit_addz { + V_DC + } else { + ix | iy + }; + list[i / 2] = iz; + } + + list +} + +// perm_flags: returns information about how a permute can be implemented. +// The return value is composed of these flag bits: + +// needs zeroing +const PERM_ZEROING: u64 = 1; +// permutation needed +const PERM_PERM: u64 = 2; +// all is zero or don't care +const PERM_ALLZERO: u64 = 4; +// fits permute with a larger block size (e.g permute Vec2q instead of Vec4i) +const PERM_LARGEBLOCK: u64 = 8; +// additional zeroing needed after permute with larger block size or shift +const PERM_ADDZ: u64 = 0x10; +// additional zeroing needed after perm_zext, perm_compress, or perm_expand +const PERM_ADDZ2: u64 = 0x20; +// permutation crossing 128-bit lanes +const PERM_CROSS_LANE: u64 = 0x40; +// same permute pattern in all 128-bit lanes +const PERM_SAME_PATTERN: u64 = 0x80; +// permutation pattern fits punpckh instruction +const PERM_PUNPCKH: u64 = 0x100; +// permutation pattern fits punpckl instruction +const PERM_PUNPCKL: u64 = 0x200; +// permutation pattern fits rotation within lanes. 4 bit count returned in bit perm_rot_count +const PERM_ROTATE: u64 = 0x400; +// permutation pattern fits shift right within lanes. 4 bit count returned in bit perm_rot_count +const PERM_SHRIGHT: u64 = 0x1000; +// permutation pattern fits shift left within lanes. negative count returned in bit perm_rot_count +const PERM_SHLEFT: u64 = 0x2000; +// permutation pattern fits rotation across lanes. 6 bit count returned in bit perm_rot_count +const PERM_ROTATE_BIG: u64 = 0x4000; +// permutation pattern fits broadcast of a single element. 
+
+// perm_flags: returns information about how a permute can be implemented.
+// The return value is composed of these flag bits:
+
+// needs zeroing
+const PERM_ZEROING: u64 = 1;
+// permutation needed
+const PERM_PERM: u64 = 2;
+// all is zero or don't care
+const PERM_ALLZERO: u64 = 4;
+// fits permute with a larger block size (e.g. permute Vec2q instead of Vec4i)
+const PERM_LARGEBLOCK: u64 = 8;
+// additional zeroing needed after permute with larger block size or shift
+const PERM_ADDZ: u64 = 0x10;
+// additional zeroing needed after perm_zext, perm_compress, or perm_expand
+const PERM_ADDZ2: u64 = 0x20;
+// permutation crossing 128-bit lanes
+const PERM_CROSS_LANE: u64 = 0x40;
+// same permute pattern in all 128-bit lanes
+const PERM_SAME_PATTERN: u64 = 0x80;
+// permutation pattern fits punpckh instruction
+const PERM_PUNPCKH: u64 = 0x100;
+// permutation pattern fits punpckl instruction
+const PERM_PUNPCKL: u64 = 0x200;
+// permutation pattern fits rotation within lanes. 4 bit count returned in bit perm_rot_count
+const PERM_ROTATE: u64 = 0x400;
+// permutation pattern fits shift right within lanes. 4 bit count returned in bit perm_rot_count
+const PERM_SHRIGHT: u64 = 0x1000;
+// permutation pattern fits shift left within lanes. negative count returned in bit perm_rot_count
+const PERM_SHLEFT: u64 = 0x2000;
+// permutation pattern fits rotation across lanes. 6 bit count returned in bit perm_rot_count
+const PERM_ROTATE_BIG: u64 = 0x4000;
+// permutation pattern fits broadcast of a single element.
+const PERM_BROADCAST: u64 = 0x8000;
+// permutation pattern fits zero extension
+const PERM_ZEXT: u64 = 0x10000;
+// permutation pattern fits vpcompress instruction
+const PERM_COMPRESS: u64 = 0x20000;
+// permutation pattern fits vpexpand instruction
+const PERM_EXPAND: u64 = 0x40000;
+// index out of range
+const PERM_OUTOFRANGE: u64 = 0x10000000;
+// rotate or shift count is in bits perm_rot_count to perm_rot_count+3
+const PERM_ROT_COUNT: u64 = 32;
+// pattern for pshufd is in bit perm_ipattern to perm_ipattern + 7 if perm_same_pattern and elementsize >= 4
+const PERM_IPATTERN: u64 = 40;
+
+fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 {
+  let mut y = a;
+  // element size is given in bytes: 4-byte elements, 8 of them
+  let flags = perm_flags::<4, 8>(indexes);
+  assert!(
+    flags & PERM_OUTOFRANGE == 0,
+    "Index out of range in permute function"
+  );
+
+  if flags & PERM_ALLZERO > 0 {
+    return f32x8::ZERO;
+  }
+
+  if flags & PERM_PERM > 0 {
+    if flags & PERM_LARGEBLOCK > 0 {
+      // constexpr EList L =
+      //   largeblock_perm<8>(indexs);
+      // y = _mm256_castpd_ps(
+      //   permute4(Vec4d(_mm256_castps_pd(a))));
+      if flags & PERM_ADDZ == 0 {
+        // no remaining zeroing
+        return y;
+      }
+    } else if flags & PERM_SAME_PATTERN > 0 {
+      if flags & PERM_PUNPCKH != 0 {
+        // fits punpckhi
+        y = _mm256_unpackhi_ps(y, y);
+      } else if flags & PERM_PUNPCKL != 0 {
+        // fits punpcklo
+        y = _mm256_unpacklo_ps(y, y);
+      } else {
+        // general permute, same pattern in both lanes
+        y = _mm256_shuffle_ps(a, a, flags >> PERM_IPATTERN as u8);
+      }
+    } else if flags & PERM_BROADCAST > 0 && flags >> PERM_ROT_COUNT == 0 {
+      // broadcast first element
+      // y = _mm256_broadcastss_ps(
+      //   _mm256_castps256_ps128(y));
+      todo!();
+    } else if flags & PERM_ZEXT > 0 {
+      // zero extension
+      // y = _mm256_castsi256_ps(_mm256_cvtepu32_epi64(
+      //   _mm256_castsi256_si128(_mm256_castps_si256(y))));
+      todo!();
+      if flags & PERM_ADDZ2 == 0 {
+        return y;
+      }
+    } else if flags & PERM_CROSS_LANE == 0 {
+      // __m256 m = constant8f();
+      // y = _mm256_permutevar_ps(a, _mm256_castps_si256(m));
+      todo!();
+    } else {
+      // full permute needed
+      // __m256i permmask =
+      //   _mm256_castps_si256(constant8f());
+      // y = _mm256_permutevar8x32_ps(a, permmask);
+      todo!();
+    }
+  }
+
+  if flags & PERM_ZEROING > 0 {
+    // constexpr EList bm = zero_mask_broad(indexs);
+    // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a));
+    // y = _mm256_and_ps(_mm256_castsi256_ps(bm1), y);
+    todo!();
+  }
+
+  y
+}
+
+// ELEM_SIZE is the size of each element in bytes; ELEMS is the element count.
+fn perm_flags<const ELEM_SIZE: usize, const ELEMS: usize>(
+  a: &[i32; ELEMS],
+) -> u64 {
+  // number of 128-bit lanes
+  let num_lanes = ELEM_SIZE * ELEMS / 16;
+  let lane_size = ELEMS / num_lanes;
+  // current lane
+  let mut lane = 0usize;
+  // rotate left count
+  let mut rot = 999u32;
+  // index to broadcasted element
+  let mut broadc = 999i32;
+  // remember certain patterns that do not fit
+  let mut patfail = 0u32;
+  // remember certain patterns need extra zeroing
+  let mut addz2 = 0u32;
+  // last index in perm_compress fit
+  let mut compresslasti = -1i32;
+  // last position in perm_compress fit
+  let mut compresslastp = -1i32;
+  // last index in perm_expand fit
+  let mut expandlasti = -1i32;
+  // last position in perm_expand fit
+  let mut expandlastp = -1i32;
+
+  let mut ret = PERM_LARGEBLOCK | PERM_SAME_PATTERN | PERM_ALLZERO;
+  let mut lane_pattern = vec![0i32; lane_size].into_boxed_slice();
+  for i in 0..ELEMS {
+    let ix = a[i];
+    // meaning of ix: -1 = set to zero, V_DC = don't care, non-negative value = permute.
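+    // e.g. ix = 3 reads source element 3 into position i, ix = -1 forces
+    // position i to zero, and ix = V_DC lets the implementation place
+    // whatever value is cheapest at position i.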
+    if ix == -1 {
+      ret |= PERM_ZEROING;
+    } else if ix != V_DC && ix as usize >= ELEMS {
+      ret |= PERM_OUTOFRANGE;
+    }
+
+    if ix >= 0 {
+      ret &= !PERM_ALLZERO;
+      if ix != i as i32 {
+        ret |= PERM_PERM;
+      }
+      if broadc == 999 {
+        // remember broadcast index
+        broadc = ix;
+      } else if broadc != ix {
+        // does not fit broadcast
+        broadc = 1000;
+      }
+    }
+
+    // check if pattern fits a larger block size:
+    // even indexes must be even, odd indexes must fit the preceding even index + 1
+    if i & 1 == 0 {
+      if ix > 0 && ix & 1 > 0 {
+        ret &= !PERM_LARGEBLOCK;
+      }
+      let iy = a[i + 1];
+      if iy >= 0 && iy & 1 == 0 {
+        ret &= !PERM_LARGEBLOCK;
+      }
+      if ix >= 0 && iy >= 0 && iy != ix + 1 {
+        ret &= !PERM_LARGEBLOCK;
+      }
+      if ix == -1 && iy >= 0 {
+        ret |= PERM_ADDZ;
+      }
+      if iy == -1 && ix >= 0 {
+        ret |= PERM_ADDZ;
+      }
+    }
+
+    lane = i / lane_size;
+    if lane == 0 {
+      // first lane, or no pattern yet
+      lane_pattern[i] = ix;
+    }
+
+    // check if crossing lanes
+    if ix >= 0 {
+      let lanei = ix as usize / lane_size;
+      if lanei != lane {
+        ret |= PERM_CROSS_LANE;
+      }
+    }
+
+    // check if same pattern in all lanes
+    if lane != 0 && ix >= 0 {
+      // not first lane
+      let j1 = i - lane * lane_size;
+      let jx = ix - (lane * lane_size) as i32;
+      if jx < 0 || jx >= lane_size as i32 {
+        // source is in another lane
+        ret &= !PERM_SAME_PATTERN;
+      }
+      if lane_pattern[j1] < 0 {
+        // pattern not known from previous lane
+        lane_pattern[j1] = jx;
+      } else if lane_pattern[j1] != jx {
+        // not same pattern
+        ret &= !PERM_SAME_PATTERN;
+      }
+    }
+
+    if ix >= 0 {
+      // check if pattern fits zero extension (perm_zext)
+      if (ix * 2) as usize != i {
+        // does not fit zero extension
+        patfail |= 1;
+      }
+      // check if pattern fits compress (perm_compress)
+      if ix > compresslasti && ix - compresslasti >= i as i32 - compresslastp {
+        if i as i32 - compresslastp > 1 {
+          // perm_compress may need additional zeroing
+          addz2 |= 2;
+        }
+        compresslasti = ix;
+        compresslastp = i as i32;
+      } else {
+        // does not fit perm_compress
+        patfail |= 2;
+      }
+      // check if pattern fits expand (perm_expand)
+      if ix > expandlasti && ix - expandlasti <= i as i32 - expandlastp {
+        if ix - expandlasti > 1 {
+          // perm_expand may need additional zeroing
+          addz2 |= 4;
+        }
+        expandlasti = ix;
+        expandlastp = i as i32;
+      } else {
+        // does not fit perm_expand
+        patfail |= 4;
+      }
+    } else if ix == -1 && i & 1 == 0 {
+      // zero extension needs additional zeroing
+      addz2 |= 1;
+    }
+  }
+
+  if ret & PERM_PERM == 0 {
+    return ret;
+  }
+
+  if ret & PERM_LARGEBLOCK == 0 {
+    ret &= !PERM_ADDZ;
+  }
+  if ret & PERM_CROSS_LANE > 0 {
+    ret &= !PERM_SAME_PATTERN;
+  }
+  if patfail & 1 == 0 {
+    ret |= PERM_ZEXT;
+    if addz2 & 1 != 0 {
+      ret |= PERM_ADDZ2;
+    }
+  } else if patfail & 2 == 0 {
+    ret |= PERM_COMPRESS;
+    if addz2 & 2 != 0 && compresslastp > 0 {
+      for j in 0..(compresslastp as usize) {
+        if a[j] == -1 {
+          ret |= PERM_ADDZ2;
+        }
+      }
+    }
+  } else if patfail & 4 == 0 {
+    ret |= PERM_EXPAND;
+    if addz2 & 4 != 0 && expandlastp > 0 {
+      for j in 0..(expandlastp as usize) {
+        if a[j] == -1 {
+          ret |= PERM_ADDZ2;
+        }
+      }
+    }
+  }
+
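+  // Illustration: [0, 2, 3, 5, -1, -1, -1, -1] keeps patfail bit 2 clear
+  // (indexes grow at least as fast as positions) and so fits PERM_COMPRESS,
+  // while [-1, 0, -1, 1, -1, 2, -1, 3] keeps bit 4 clear and fits PERM_EXPAND.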
+  if ret & PERM_SAME_PATTERN > 0 {
+    // same pattern in all lanes. check if it fits specific patterns
+    let mut fit = true;
+    // fit shift or rotate
+    for i in 0..lane_size {
+      if lane_pattern[i] >= 0 {
+        let rot1 = lane_pattern[i] as u32 + lane_size as u32
+          - i as u32 % lane_size as u32;
+        if rot == 999 {
+          rot = rot1;
+        } else if rot != rot1 {
+          fit = false;
+        }
+      }
+    }
+    rot &= lane_size as u32 - 1;
+    if fit {
+      // fits rotate, and possible shift
+      // the rotate count is converted to bytes for palignr
+      let rot2 = ((rot * ELEM_SIZE as u32) & 0xF) as u64;
+      ret |= rot2 << PERM_ROT_COUNT;
+      ret |= PERM_ROTATE;
+      // fit shift left
+      fit = true;
+      let mut i = 0;
+      while (i + rot as usize) < lane_size {
+        // check if first rot elements are zero or don't care
+        if lane_pattern[i] >= 0 {
+          fit = false;
+        }
+        i += 1;
+      }
+      if fit {
+        ret |= PERM_SHLEFT;
+        while i < lane_size {
+          if lane_pattern[i] == -1 {
+            ret |= PERM_ADDZ;
+          }
+          i += 1;
+        }
+      }
+      // fit shift right
+      fit = true;
+      i = lane_size - rot as usize;
+      while i < lane_size {
+        // check if last (lanesize-rot) elements are zero or don't care
+        if lane_pattern[i] >= 0 {
+          fit = false;
+        }
+        i += 1;
+      }
+      if fit {
+        ret |= PERM_SHRIGHT;
+        // restart from the start of the lane for the zeroing check
+        i = 0;
+        while i < lane_size - rot as usize {
+          if lane_pattern[i] == -1 {
+            ret |= PERM_ADDZ;
+          }
+          i += 1;
+        }
+      }
+    }
+
+    // fit punpckhi
+    fit = true;
+    let mut j2 = lane_size / 2;
+    for i in 0..lane_size {
+      if lane_pattern[i] >= 0 && lane_pattern[i] != j2 as i32 {
+        fit = false;
+      }
+      if (i & 1) != 0 {
+        j2 += 1;
+      }
+    }
+    if fit {
+      ret |= PERM_PUNPCKH;
+    }
+    // fit punpcklo
+    fit = true;
+    j2 = 0;
+    for i in 0..lane_size {
+      if lane_pattern[i] >= 0 && lane_pattern[i] != j2 as i32 {
+        fit = false;
+      }
+      if (i & 1) != 0 {
+        j2 += 1;
+      }
+    }
+    if fit {
+      ret |= PERM_PUNPCKL;
+    }
+    // fit pshufd
+    if ELEM_SIZE >= 4 {
+      let mut p = 0u64;
+      for i in 0..lane_size {
+        if lane_size == 4 {
+          p |= ((lane_pattern[i] & 3) as u64) << (2 * i as u64);
+        } else {
+          // lanesize = 2
+          p |= ((lane_pattern[i] & 1) as u64 * 10 + 4) << (4 * i as u64);
+        }
+      }
+      ret |= p << PERM_IPATTERN;
+    }
+  } else {
+    // not same pattern in all lanes
+    if num_lanes > 1 {
+      // Try if it fits big rotate
+      for i in 0..ELEMS {
+        let ix = a[i];
+        if ix >= 0 {
+          // rotate count
+          let mut rot2: u32 =
+            (ix as u32 + ELEMS as u32 - i as u32) % ELEMS as u32;
+          if rot == 999 {
+            // save rotate count
+            rot = rot2;
+          } else if rot != rot2 {
+            // does not fit big rotate
+            rot = 1000;
+            break;
+          }
+        }
+      }
+      if rot < ELEMS as u32 {
+        // fits big rotate
+        ret |= PERM_ROTATE_BIG | (rot as u64) << PERM_ROT_COUNT;
+      }
+    }
+  }
+
+  if broadc < 999
+    && ret & (PERM_ROTATE | PERM_SHRIGHT | PERM_SHLEFT | PERM_ROTATE_BIG) == 0
+  {
+    // fits broadcast
+    ret |= PERM_BROADCAST | (broadc as u64) << PERM_ROT_COUNT;
+  }
-  r
+  ret
 }

From 22346250aea6a89ae4abb688ac21aa00241352c4 Mon Sep 17 00:00:00 2001
From: Josh Holmer
Date: Sun, 6 Nov 2022 00:06:10 -0400
Subject: [PATCH 10/11] More code

---
 Cargo.toml | 2 +-
 src/denoise/blend.rs | 202 ++++++++++++++++++++++++++++++++++++-------
 2 files changed, 170 insertions(+), 34 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 664cd428ec..eee3df5814 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -110,7 +110,7 @@ av1-grain = { version = "0.2.0", features = ["serialize"] }
 serde-big-array = { version = "0.4.1", optional = true }
 # Used for parsing film grain table files
 nom = "7.0.0"
-wide = { git = "https://github.com/shssoichiro/wide", branch = "fast-cast-fns" }
+wide = { git = "https://github.com/shssoichiro/wide", branch = "additional-functions" }
 num-complex = "0.4.2"
 
 [dependencies.image]
diff --git a/src/denoise/blend.rs b/src/denoise/blend.rs
index 19411222e6..055abda27b 100644
--- a/src/denoise/blend.rs
+++ b/src/denoise/blend.rs
@@ -139,40 +139,52 @@ fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 {
     // nothing from b. just permute a
     return permute8(indexes, a);
   } else if flags & BLEND_A == 0 {
-    let l = blend_perm_indexes::<8, 2>(indexes);
+    let l = blend_perm_indexes::<8, 16, 2>(indexes);
     return permute8(cast(&l[8..]), b);
   } else if flags & (BLEND_PERMA | BLEND_PERMB) == 0 {
     // no permutation, only blending
-    let mb = make_bit_mask::<8>(0x303, indexes) as u8;
-    y = mb.blend(b, a);
+    let mb = make_bit_mask::<8, 0x303>(indexes) as u8;
+    y = f32x8::new([
+      if mb & 0b10000000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b01000000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00100000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00010000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00001000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000100 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000010 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000001 > 0 { -1f32 } else { 0f32 },
+    ])
+    .blend(b, a);
   } else if flags & BLEND_PUNPCKLAB > 0 {
-    // y = _mm256_unpacklo_ps(a, b);
-    todo!();
+    y = a.interleave_low(b);
   } else if flags & BLEND_PUNPCKLBA > 0 {
-    // y = _mm256_unpacklo_ps(b, a);
-    todo!();
+    y = b.interleave_low(a);
  } else if flags & BLEND_PUNPCKHAB > 0 {
-    // y = _mm256_unpackhi_ps(a, b);
-    todo!();
+    y = a.interleave_high(b);
  } else if flags & BLEND_PUNPCKHBA > 0 {
-    // y = _mm256_unpackhi_ps(b, a);
-    todo!();
+    y = b.interleave_high(a);
  } else if flags & BLEND_SHUFAB > 0 {
-    // use floating point instruction shufpd
-    // y = _mm256_shuffle_ps(a, b, (flags >> BLEND_SHUFPATTERN) as u8);
-    todo!();
+    y = a.shuffle_nonconst(b, (flags >> BLEND_SHUFPATTERN) as u8);
  } else if flags & BLEND_SHUFBA > 0 {
-    // use floating point instruction shufpd
-    // y = _mm256_shuffle_ps(b, a, (flags >> BLEND_SHUFPATTERN) as u8);
-    todo!();
+    y = b.shuffle_nonconst(a, (flags >> BLEND_SHUFPATTERN) as u8);
  } else {
     // No special cases
     // permute a and b separately, then blend.
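+    // e.g. for indexes [1, 8, 2, 10, 5, 12, 6, 14] this gathers the a-side
+    // pattern [1, _, 2, _, 5, _, 6, _] and the b-side pattern
+    // [_, 0, _, 2, _, 4, _, 6] (_ = V_DC), then merges them with the
+    // even/odd blend mask computed below.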
-    let l = blend_perm_indexes::<8, 0>(indexes);
+    let l = blend_perm_indexes::<8, 16, 0>(indexes);
     let ya = permute8(cast(&l[..8]), a);
     let yb = permute8(cast(&l[8..]), b);
-    let mb = make_bit_mask::<8>(0x303, indexes) as u8;
-    y = mb.blend(yb, ya);
+    let mb = make_bit_mask::<8, 0x303>(indexes) as u8;
+    y = f32x8::new([
+      if mb & 0b10000000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b01000000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00100000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00010000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00001000 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000100 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000010 > 0 { -1f32 } else { 0f32 },
+      if mb & 0b00000001 > 0 { -1f32 } else { 0f32 },
+    ])
+    .blend(yb, ya);
   }
 
   if flags & BLEND_ZEROING > 0 {
     // additional zeroing needed
@@ -307,7 +319,78 @@ fn select_blend8(action: Option, a: f64x8, b: f64x8) -> f64x4 {
 
 #[inline(always)]
 fn blend4_f64(indexes: &[i32; 4], a: f64x4, b: f64x4) -> f64x4 {
-  todo!()
+  let flags = blend_flags::<4, { size_of::<f64x4>() }>(indexes);
+  assert!(
+    flags & BLEND_OUTOFRANGE == 0,
+    "Index out of range in blend function"
+  );
+
+  if flags & BLEND_ALLZERO != 0 {
+    return f64x4::ZERO;
+  }
+
+  if flags & BLEND_B == 0 {
+    return permute4(indexes, a);
+  }
+  if flags & BLEND_A == 0 {
+    return permute4(
+      &[
+        if indexes[0] < 0 { indexes[0] } else { indexes[0] & 3 },
+        if indexes[1] < 0 { indexes[1] } else { indexes[1] & 3 },
+        if indexes[2] < 0 { indexes[2] } else { indexes[2] & 3 },
+        if indexes[3] < 0 { indexes[3] } else { indexes[3] & 3 },
+      ],
+      b,
+    );
+  }
+
+  if flags & (BLEND_PERMA | BLEND_PERMB) == 0 {
+    // no permutation, only blending
+
+    // constexpr uint8_t mb =
+    //   (uint8_t)make_bit_mask<4, 0x302>(indexs); // blend mask
+    // y = _mm256_blend_pd(a, b, mb); // duplicate each bit
+    todo!();
+  } else if flags & BLEND_LARGEBLOCK != 0 {
+    // blend and permute 128-bit blocks
+    // constexpr EList L =
+    //   largeblock_perm<4>(indexs); // get 128-bit blend pattern
+    // constexpr uint8_t pp = (L.a[0] & 0xF) | uint8_t(L.a[1] & 0xF) << 4;
+    // y = _mm256_permute2f128_pd(a, b, pp);
+    todo!();
+  } else if flags & BLEND_PUNPCKLAB != 0 {
+    y = a.interleave_low(b);
+  } else if flags & BLEND_PUNPCKLBA != 0 {
+    y = b.interleave_low(a);
+  } else if flags & BLEND_PUNPCKHAB != 0 {
+    y = a.interleave_high(b);
+  } else if flags & BLEND_PUNPCKHBA != 0 {
+    y = b.interleave_high(a);
+  } else if flags & BLEND_SHUFAB != 0 {
+    y = a.shuffle_nonconst(b, ((flags >> BLEND_SHUFPATTERN) as u8) & 0xF);
+  } else if flags & BLEND_SHUFBA != 0 {
+    y = b.shuffle_nonconst(a, ((flags >> BLEND_SHUFPATTERN) as u8) & 0xF);
+  } else {
+    // No special cases
+    // constexpr EList L =
+    //   blend_perm_indexes<4, 0>(indexs); // get permutation indexes
+    // __m256d ya = permute4(a);
+    // __m256d yb = permute4(b);
+    // constexpr uint8_t mb =
+    //   (uint8_t)make_bit_mask<4, 0x302>(indexs); // blend mask
+    // y = _mm256_blend_pd(ya, yb, mb);
+    todo!();
+  }
+
+  if flags & BLEND_ZEROING != 0 {
+    // additional zeroing needed
+    // constexpr EList bm = zero_mask_broad(indexs);
+    // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a));
+    // y = _mm256_and_pd(_mm256_castsi256_pd(bm1), y);
+    todo!();
+  }
+
+  y
 }

// blend_flags: returns information about how a blend function can be implemented
@@ -500,7 +583,7 @@ fn blend_flags<const N: usize, const V: usize>(a: &[i32; N]) -> u64 {
     let ix = lane_pattern[iu as usize];
     if ix >= 0 {
       let ix = ix as u32;
-      let t = ix & !(N as u32);
+      let mut t = ix & !(N as u32);
       if (ix & N as u32) > 0 {
         t += lane_size as u32;
       }
@@ -549,7 +632,7 @@ fn blend_flags<const N: usize, const V: usize>(a: &[i32; N]) -> u64 {
       }
     }
     if ret & (BLEND_SHUFAB | BLEND_SHUFBA) > 0 {
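+      // each lane_pattern entry contributes two bits to the shufps
+      // immediate, e.g. pattern [0, 1, 4, 5] packs to 0b01_00_01_00 (0x44)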
       // fits shufps/shufpd
-      let shuf_pattern = 0u8;
+      let mut shuf_pattern = 0u8;
       for iu in 0..lane_size {
         shuf_pattern |= ((lane_pattern[iu] & 3) as u8) << (iu * 2);
       }
@@ -689,10 +772,9 @@ fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 {
 
   if flags & PERM_PERM > 0 {
     if flags & PERM_LARGEBLOCK > 0 {
-      // constexpr EList L =
-      //   largeblock_perm<8>(indexs);
-      // y = _mm256_castpd_ps(
-      //   permute4(Vec4d(_mm256_castps_pd(a))));
+      let l = largeblock_perm::<8, 4>(indexes);
+      // SAFETY: Types are of same size
+      let b4: f32x8 = unsafe { transmute(permute4_f64(&l, transmute(a))) };
       if flags & PERM_ADDZ == 0 {
         // no remaining zeroing
         return y;
@@ -700,19 +782,17 @@ fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 {
     } else if flags & PERM_SAME_PATTERN > 0 {
       if flags & PERM_PUNPCKH != 0 {
         // fits punpckhi
-        y = _mm256_unpackhi_ps(y, y);
+        y = y.interleave_high(y);
       } else if flags & PERM_PUNPCKL != 0 {
         // fits punpcklo
-        y = _mm256_unpacklo_ps(y, y);
+        y = y.interleave_low(y);
       } else {
         // general permute, same pattern in both lanes
-        y = _mm256_shuffle_ps(a, a, flags >> PERM_IPATTERN as u8);
+        y = a.shuffle_nonconst(a, (flags >> PERM_IPATTERN) as u8);
       }
     } else if flags & PERM_BROADCAST > 0 && flags >> PERM_ROT_COUNT == 0 {
       // broadcast first element
-      // y = _mm256_broadcastss_ps(
-      //   _mm256_castps256_ps128(y));
-      todo!();
+      y = f32x8::from(y.to_array()[0]);
     } else if flags & PERM_ZEXT > 0 {
       // zero extension
       // y = _mm256_castsi256_ps(_mm256_cvtepu32_epi64(
       //   _mm256_castsi256_si128(_mm256_castps_si256(y))));
@@ -1058,3 +1138,59 @@ fn perm_flags<const ELEM_SIZE: usize, const ELEMS: usize>(
   ret
 }
+
+fn make_bit_mask<const LEN: usize, const BITS: u32>(
+  a: &[i32; LEN],
+) -> u64 {
+  let mut ret = 0u64;
+  // the low byte of BITS selects which bit of each index to extract
+  let sel_bit_idx = (BITS & 0xFF) as u8;
+  let mut ret_bit_idx = 0u64;
+  let mut flip = 0u64;
+  for i in 0..LEN {
+    let ix = a[i];
+    if ix < 0 {
+      ret_bit_idx = ((BITS >> 10) & 1) as u64;
+    } else {
+      ret_bit_idx = (((ix as u32) >> sel_bit_idx) & 1) as u64;
+      if i < LEN / 2 {
+        flip = ((BITS >> 8) & 1) as u64;
+      } else {
+        flip = ((BITS >> 9) & 1) as u64;
+      }
+      ret_bit_idx ^= flip ^ 1;
+    }
+    ret |= ret_bit_idx << i;
+  }
+  ret
+}
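+// For example: make_bit_mask::<8, 0x303>(&[0, 9, 2, 11, 4, 13, 6, 15])
+// extracts bit 3 of each index (the a/b select bit for an 8-lane blend);
+// with the flip bits in 0x300 applied, the result is 0b10101010, i.e. set
+// bits mark the lanes that blend8_f32 takes from b.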
+
+// blend_perm_indexes: return an Indexlist for implementing a blend function as
+// two permutations. N = vector size.
+// dozero = 0: let unused elements be don't care. The two permutation results
+//   must be blended
+// dozero = 1: zero unused elements in each permutation. The two permutation
+//   results can be OR'ed
+// dozero = 2: indexes that are -1 or V_DC are preserved
+fn blend_perm_indexes<const N: usize, const N2: usize, const DO_ZERO: u8>(
+  a: &[i32; N],
+) -> [i32; N2] {
+  assert!(N * 2 == N2);
+  assert!(DO_ZERO <= 2);
+  let mut list = [0i32; N2];
+  let u = if DO_ZERO > 0 { -1 } else { V_DC };
+  for j in 0..N {
+    let ix = a[j];
+    if ix < 0 {
+      if DO_ZERO == 2 {
+        a[j] = ix;
+        a[j + N] = ix;
+      } else {
+        a[j] = u;
+        a[j + N] = u;
+      }
+    } else if ix < N as i32 {
+      a[j] = ix;
+      a[j + N] = u;
+    } else {
+      a[j] = u;
+      a[j + N] = ix - N as i32;
+    }
+  }
+  list
+}
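+// For example: blend_perm_indexes::<4, 8, 0>(&[5, 1, -1, 6]) produces the
+// a-side permute [_, 1, _, _] and the b-side permute [1, _, _, 2]
+// (_ = V_DC), which blend4_f64 then merges with a blend mask.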

From 35643fded15ee1eacc351274a21d41cab747bcc3 Mon Sep 17 00:00:00 2001
From: Josh Holmer
Date: Wed, 9 Nov 2022 16:26:59 -0500
Subject: [PATCH 11/11] Even more code

---
 src/denoise/blend.rs | 208 +++++++++++++++++++++++++++++++++----------
 1 file changed, 160 insertions(+), 48 deletions(-)

diff --git a/src/denoise/blend.rs b/src/denoise/blend.rs
index 055abda27b..5eaa122387 100644
--- a/src/denoise/blend.rs
+++ b/src/denoise/blend.rs
@@ -3,7 +3,7 @@
 use std::mem::{size_of, transmute};
 
 use arrayvec::ArrayVec;
-use wide::{f32x8, f64x4};
+use wide::{f32x8, f64x4, i32x8, i64x4};
 
 use crate::util::cast;
 
@@ -129,8 +129,7 @@ fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 {
     // blend and permute 64-bit blocks
     let l = largeblock_perm::<8, 4>(indexes);
     // SAFETY: Types are of same size
-    let b4: f32x8 =
-      unsafe { transmute(blend4_f64(&l, transmute(a), transmute(b))) };
+    y = unsafe { transmute(blend4_f64(&l, transmute(a), transmute(b))) };
     if flags & BLEND_ADDZ == 0 {
       // no remaining zeroing
       return y;
@@ -145,14 +144,14 @@ fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 {
     // no permutation, only blending
     let mb = make_bit_mask::<8, 0x303>(indexes) as u8;
     y = f32x8::new([
-      if mb & 0b10000000 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b01000000 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00100000 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00010000 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00001000 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00000100 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00000010 > 0 { -1f32 } else { 0f32 },
-      if mb & 0b00000001 > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 0) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 1) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 2) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 3) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 4) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 5) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 6) > 0 { -1f32 } else { 0f32 },
+      if mb & (1 << 7) > 0 { -1f32 } else { 0f32 },
     ])
     .blend(b, a);
   } else if flags & BLEND_PUNPCKLAB > 0 {
@@ -188,10 +187,10 @@ fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 {
 
   if flags & BLEND_ZEROING > 0 {
     // additional zeroing needed
-    // let bm = zero_mask_broad::<8>(indexes);
-    // let bm1 = _mm256_loadu_si256(bm as __m256i);
-    // y = _mm256_and_ps(_mm256_castsi256_ps(bm1), y);
-    todo!();
+    let bm = i32x8::new(zero_mask_broad_8x32(indexes));
+    // SAFETY: Types have the same size
+    let bm1: f32x8 = unsafe { transmute(bm) };
+    y = bm1 & y;
   }
 
   y
@@ -344,18 +343,21 @@ fn blend4_f64(indexes: &[i32; 4], a: f64x4, b: f64x4) -> f64x4 {
     );
   }
 
+  let mut y = a;
   if flags & (BLEND_PERMA | BLEND_PERMB) == 0 {
     // no permutation, only blending
-
-    // constexpr uint8_t mb =
-    //   (uint8_t)make_bit_mask<4, 0x302>(indexs); // blend mask
-    // y = _mm256_blend_pd(a, b, mb); // duplicate each bit
-    todo!();
+    let mb = make_bit_mask::<4, 0x302>(indexes);
+    y = f64x4::new([
+      if mb & (1 << 0) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 1) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 2) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 3) > 0 { -1f64 } else { 0f64 },
+    ])
+    .blend(b, a);
   } else if flags & BLEND_LARGEBLOCK != 0 {
     // blend and permute 128-bit blocks
-    // constexpr EList L =
-    //   largeblock_perm<4>(indexs); // get 128-bit blend pattern
-    // constexpr uint8_t pp = (L.a[0] & 0xF) | uint8_t(L.a[1] & 0xF) << 4;
+    let l = largeblock_perm::<4, 2>(indexes);
+    let pp = (l[0] & 0xF) as u8 | ((l[1] & 0xF) as u8) << 4;
     // y = _mm256_permute2f128_pd(a, b, pp);
     todo!();
   } else if flags & BLEND_PUNPCKLAB != 0 {
@@ -372,22 +374,25 @@ fn blend4_f64(indexes: &[i32; 4], a: f64x4, b: f64x4) -> f64x4 {
     y = b.shuffle_nonconst(a, ((flags >> BLEND_SHUFPATTERN) as u8) & 0xF);
   } else {
     // No special cases
-    // constexpr EList L =
-    //   blend_perm_indexes<4, 0>(indexs); // get permutation indexes
-    // __m256d ya = permute4(a);
-    // __m256d yb = permute4(b);
-    // constexpr uint8_t mb =
-    //   (uint8_t)make_bit_mask<4, 0x302>(indexs); // blend mask
-    // y = _mm256_blend_pd(ya, yb, mb);
-    todo!();
+    let l = blend_perm_indexes::<4, 8, 0>(indexes);
+    let ya = permute4(cast(&l[..4]), a);
+    let yb = permute4(cast(&l[4..]), b);
+    let mb = make_bit_mask::<4, 0x302>(indexes);
+    y = f64x4::new([
+      if mb & (1 << 0) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 1) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 2) > 0 { -1f64 } else { 0f64 },
+      if mb & (1 << 3) > 0 { -1f64 } else { 0f64 },
+    ])
+    .blend(yb, ya);
   }
 
   if flags & BLEND_ZEROING != 0 {
     // additional zeroing needed
-    // constexpr EList bm = zero_mask_broad(indexs);
-    // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a));
-    // y = _mm256_and_pd(_mm256_castsi256_pd(bm1), y);
-    todo!();
+    let bm = i64x4::new(zero_mask_broad_4x64(indexes));
+    // SAFETY: Types have the same size
+    let bm1: f64x4 = unsafe { transmute(bm) };
+    y = bm1 & y;
   }
 
   y
@@ -774,7 +779,7 @@ fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 {
     if flags & PERM_LARGEBLOCK > 0 {
       let l = largeblock_perm::<8, 4>(indexes);
       // SAFETY: Types are of same size
-      let b4: f32x8 = unsafe { transmute(permute4_f64(&l, transmute(a))) };
+      y = unsafe { transmute(permute4(&l, transmute(a))) };
       if flags & PERM_ADDZ == 0 {
         // no remaining zeroing
         return y;
@@ -817,10 +822,10 @@ fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 {
   }
 
   if flags & PERM_ZEROING > 0 {
-    // constexpr EList bm = zero_mask_broad(indexs);
-    // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a));
-    // y = _mm256_and_ps(_mm256_castsi256_ps(bm1), y);
-    todo!();
+    let bm = i32x8::new(zero_mask_broad_8x32(indexes));
+    // SAFETY: Types have the same size
+    let bm1: f32x8 = unsafe { transmute(bm) };
+    y = bm1 & y;
   }
 
   y
@@ -1178,19 +1183,126 @@ fn blend_perm_indexes<const N: usize, const N2: usize, const DO_ZERO: u8>(
     let ix = a[j];
     if ix < 0 {
       if DO_ZERO == 2 {
-        a[j] = ix;
-        a[j + N] = ix;
+        list[j] = ix;
+        list[j + N] = ix;
       } else {
-        a[j] = u;
-        a[j + N] = u;
+        list[j] = u;
+        list[j + N] = u;
       }
     } else if ix < N as i32 {
-      a[j] = ix;
-      a[j + N] = u;
+      list[j] = ix;
+      list[j + N] = u;
     } else {
-      a[j] = u;
-      a[j + N] = ix - N as i32;
+      list[j] = u;
+      list[j + N] = ix - N as i32;
     }
   }
   list
 }
+
+#[inline(always)]
+fn zero_mask_broad_8x32(a: &[i32; 8]) -> [i32; 8] {
+  let mut u = [0i32; 8];
+  for i in 0..8 {
+    u[i] = if a[i] >= 0 { -1 } else { 0 };
+  }
+  u
+}
+
+#[inline(always)]
+fn zero_mask_broad_4x64(a: &[i32; 4]) -> [i64; 4] {
+  let mut u = [0i64; 4];
+  for i in 0..4 {
+    u[i] = if a[i] >= 0 { -1 } else { 0 };
+  }
+  u
+}
+
+fn permute4(indexes: &[i32; 4], a: f64x4) -> f64x4 {
+  let mut y = a;
+  // element size is given in bytes: 8-byte elements, 4 of them
+  let flags = perm_flags::<8, 4>(indexes);
+
+  assert!(
+    flags & PERM_OUTOFRANGE == 0,
+    "Index out of range in permute function"
+  );
+
+  if flags & PERM_ALLZERO != 0 {
+    return f64x4::ZERO;
+  }
+
+  if flags & PERM_LARGEBLOCK != 0 {
+    // permute 128-bit blocks
+    let l = largeblock_perm::<4, 2>(indexes);
+    let j0 = l[0];
+    let j1 = l[1];
+
+    if j0 == 0 && j1 == -1 && flags & PERM_ADDZ == 0 {
+      // zero extend
+      // return _mm256_zextpd128_pd256(_mm256_castpd256_pd128(y));
+      todo!();
+    }
+    if j0 == 1 && j1 < 0 && flags & PERM_ADDZ == 0 {
+      // extract upper part, zero extend
+      // return _mm256_zextpd128_pd256(_mm256_extractf128_pd(y, 1));
+      todo!();
+    }
+    if flags & PERM_PERM != 0 && flags & PERM_ZEROING == 0 {
+      // return _mm256_permute2f128_pd(y, y, (j0 & 1) | (j1 & 1) << 4);
+      todo!();
+    }
+  }
+
+  if flags & PERM_PERM != 0 {
+    // permutation needed
+    if flags & PERM_SAME_PATTERN != 0 {
+      // same pattern in both lanes
+      if flags & PERM_PUNPCKH != 0 {
+        y = y.interleave_high(y);
+      } else if flags & PERM_PUNPCKL != 0 {
+        y = y.interleave_low(y);
+      } else {
+        // general permute
+        let mm0 = (indexes[0] & 1)
+          | (indexes[1] & 1) << 1
+          | (indexes[2] & 1) << 2
+          | (indexes[3] & 1) << 3;
+        // select within same lane
+        // y = _mm256_permute_pd(a, mm0);
+        todo!();
+      }
+    } else if flags & PERM_BROADCAST != 0 && (flags >> PERM_ROT_COUNT) == 0 {
+      // broadcast first element
+      // y = _mm256_broadcastsd_pd(
+      //   _mm256_castpd256_pd128(y));
+      todo!();
+    } else {
+      // different patterns in two lanes
+      if flags & PERM_CROSS_LANE == 0 {
+        // no lane crossing
+        // constexpr uint8_t mm0 =
+        //   (i0 & 1) | (i1 & 1) << 1 | (i2 & 1) << 2 | (i3 & 1) << 3;
+        // y = _mm256_permute_pd(a, mm0); // select within same lane
+        todo!();
+      } else {
+        // // full permute
+        // constexpr uint8_t mms =
+        //   (i0 & 3) | (i1 & 3) << 2 | (i2 & 3) << 4 | (i3 & 3) << 6;
+        // y = _mm256_permute4x64_pd(a, mms);
+        todo!();
+      }
+    }
+  }
+
+  if flags & PERM_ZEROING != 0 {
+    // additional zeroing needed
+    // use broad mask
+    // constexpr EList bm = zero_mask_broad(indexs);
+    // // y = _mm256_and_pd(_mm256_castsi256_pd( Vec4q().load(bm.a) ), y); // does
+    // // not work with INSTRSET = 7
+    // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a));
+    // y = _mm256_and_pd(_mm256_castsi256_pd(bm1), y);
+    todo!();
+  }
+  y
+}
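
A usage sketch (illustrative only, not part of the patch series): assuming it is placed in a #[cfg(test)] module inside src/denoise/blend.rs, so the private functions are in scope, and that the PATCH 10/11 changes are applied, the punpckh fast path can be exercised like this:

  #[test]
  fn permute8_punpckh_pattern() {
    let v = f32x8::new([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    // [2, 2, 3, 3, 6, 6, 7, 7] repeats one pattern in both 128-bit lanes and
    // matches the punpckh layout, so perm_flags reports
    // PERM_SAME_PATTERN | PERM_PUNPCKH and permute8 lowers to a single
    // interleave_high rather than a general shuffle.
    let y = permute8(&[2, 2, 3, 3, 6, 6, 7, 7], v);
    assert_eq!(y.to_array(), [2.0, 2.0, 3.0, 3.0, 6.0, 6.0, 7.0, 7.0]);
  }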