diff --git a/Cargo.toml b/Cargo.toml
index 320e7aafa4..eee3df5814 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,7 +36,6 @@ binaries = [
   "fern",
   "console",
   "av-metrics",
-  "nom",
 ]
 default = ["binaries", "asm", "threading", "signal_support"]
 asm = ["nasm-rs", "cc"]
@@ -103,11 +102,16 @@ simd_helpers = "0.1"
 wasm-bindgen = { version = "0.2.63", optional = true }
 rust_hawktracer = "0.7.0"
 const_fn_assert = "0.1.2"
-nom = { version = "7.0.0", optional = true }
+# `unreachable!` macro which panics in debug mode
+# and optimizes away in release mode
 new_debug_unreachable = "1.0.4"
 once_cell = "1.13.0"
 av1-grain = { version = "0.2.0", features = ["serialize"] }
 serde-big-array = { version = "0.4.1", optional = true }
+# Used for parsing film grain table files
+nom = "7.0.0"
+wide = { git = "https://github.com/shssoichiro/wide", branch = "additional-functions" }
+num-complex = "0.4.2"
 
 [dependencies.image]
 version = "0.24.3"
diff --git a/benches/bench.rs b/benches/bench.rs
index 4929d83fdc..e5395b99de 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -7,6 +7,7 @@
 // Media Patent License 1.0 was not distributed with this source code in the
 // PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 
+mod denoise;
 mod dist;
 mod mc;
 mod plane;
@@ -23,12 +24,15 @@ use rav1e::bench::partition::*;
 use rav1e::bench::predict::*;
 use rav1e::bench::rdo::*;
 use rav1e::bench::transform::*;
+use rav1e::prelude::*;
 
 use crate::plane::plane;
 use crate::rdo::rdo;
 use crate::transform::{forward_transforms, inverse_transforms};
 use criterion::*;
+use rand::Rng;
+use rand_chacha::ChaChaRng;
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -193,6 +197,26 @@ fn update_cdf_4(b: &mut Bencher) {
   });
 }
 
+fn fill_plane<T: Pixel>(ra: &mut ChaChaRng, plane: &mut Plane<T>) {
+  let stride = plane.cfg.stride;
+  for row in plane.data_origin_mut().chunks_mut(stride) {
+    for pixel in row {
+      let v: u8 = ra.gen();
+      *pixel = T::cast_from(v);
+    }
+  }
+}
+
+fn new_plane<T: Pixel>(
+  ra: &mut ChaChaRng, width: usize, height: usize,
+) -> Plane<T> {
+  let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8);
+
+  fill_plane(ra, &mut p);
+
+  p
+}
+
 criterion_group!(intra_prediction, predict::pred_bench,);
 
 criterion_group!(cfl, cfl_rdo);
@@ -217,5 +241,6 @@ criterion_main!(
   ec,
   rdo,
   plane,
-  mc::mc
+  mc::mc,
+  denoise::denoise
 );
diff --git a/benches/denoise.rs b/benches/denoise.rs
new file mode 100644
index 0000000000..5daad7577b
--- /dev/null
+++ b/benches/denoise.rs
@@ -0,0 +1,76 @@
+use super::new_plane;
+use criterion::*;
+use rand::SeedableRng;
+use rand_chacha::ChaChaRng;
+use rav1e::bench::denoise::*;
+use rav1e::prelude::*;
+use std::collections::BTreeMap;
+use std::sync::Arc;
+
+fn bench_dft_denoiser_8b(c: &mut Criterion) {
+  let mut ra = ChaChaRng::from_seed([0; 32]);
+  let w = 640;
+  let h = 480;
+  let mut frame_queue = BTreeMap::new();
+  for i in 0..3 {
+    frame_queue.insert(
+      i,
+      Some(Arc::new(Frame {
+        planes: [
+          new_plane::<u8>(&mut ra, w, h),
+          new_plane::<u8>(&mut ra, w / 2, h / 2),
+          new_plane::<u8>(&mut ra, w / 2, h / 2),
+        ],
+      })),
+    );
+  }
+  frame_queue.insert(3, None);
+
+  c.bench_function("dft_denoiser_8b", |b| {
+    b.iter_with_setup(
+      || DftDenoiser::new(2.0, w, h, 8, ChromaSampling::Cs420),
+      |mut denoiser| {
+        for _ in 0..3 {
+          let _ = black_box(denoiser.filter_frame(&frame_queue));
+        }
+      },
+    )
+  });
+}
+
+fn bench_dft_denoiser_10b(c: &mut Criterion) {
+  let mut ra = ChaChaRng::from_seed([0; 32]);
+  let w = 640;
+  let h = 480;
+  let mut frame_queue = BTreeMap::new();
+  for i in 0..3 {
+    let mut frame = Frame {
+      planes: [
+        new_plane::<u16>(&mut ra, w, h),
+        new_plane::<u16>(&mut ra, w / 2, h / 2),
+        new_plane::<u16>(&mut ra, w / 2, h / 2),
+      ],
+    };
+    for p in 0..3 {
+      // Shift from 16-bit to 10-bit
+      frame.planes[p].data.iter_mut().for_each(|pix| {
+        *pix = *pix >> 6;
+      });
+    }
+    frame_queue.insert(i, Some(Arc::new(frame)));
+  }
+  frame_queue.insert(3, None);
+
+  c.bench_function("dft_denoiser_10b", |b| {
+    b.iter_with_setup(
+      || DftDenoiser::new(2.0, w, h, 10, ChromaSampling::Cs420),
+      |mut denoiser| {
+        for _ in 0..3 {
+          let _ = black_box(denoiser.filter_frame(&frame_queue));
+        }
+      },
+    )
+  });
+}
+
+criterion_group!(denoise, bench_dft_denoiser_8b, bench_dft_denoiser_10b);
diff --git a/benches/dist.rs b/benches/dist.rs
index d86a7743e0..0be7e95361 100644
--- a/benches/dist.rs
+++ b/benches/dist.rs
@@ -9,6 +9,7 @@
 
 #![allow(clippy::trivially_copy_pass_by_ref)]
 
+use super::new_plane;
 use criterion::*;
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaChaRng;
@@ -69,26 +70,6 @@ const DIST_BENCH_SET: &[(BlockSize, usize)] = &[
   (BLOCK_64X16, 10),
 ];
 
-fn fill_plane<T: Pixel>(ra: &mut ChaChaRng, plane: &mut Plane<T>) {
-  let stride = plane.cfg.stride;
-  for row in plane.data_origin_mut().chunks_mut(stride) {
-    for pixel in row {
-      let v: u8 = ra.gen();
-      *pixel = T::cast_from(v);
-    }
-  }
-}
-
-fn new_plane<T: Pixel>(
-  ra: &mut ChaChaRng, width: usize, height: usize,
-) -> Plane<T> {
-  let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8);
-
-  fill_plane(ra, &mut p);
-
-  p
-}
-
 type DistFn<T> = fn(
   plane_org: &PlaneRegion<'_, T>,
   plane_ref: &PlaneRegion<'_, T>,
diff --git a/benches/mc.rs b/benches/mc.rs
index 4d3f81add7..8b6156f89b 100644
--- a/benches/mc.rs
+++ b/benches/mc.rs
@@ -1,7 +1,8 @@
 #![allow(clippy::unit_arg)]
 
+use super::new_plane;
 use criterion::*;
-use rand::{Rng, SeedableRng};
+use rand::SeedableRng;
 use rand_chacha::ChaChaRng;
 use rav1e::bench::cpu_features::*;
 use rav1e::bench::frame::{AsRegion, PlaneOffset, PlaneSlice};
@@ -525,26 +526,6 @@ criterion_group!(
   bench_prep_8tap_center_hbd
 );
 
-fn fill_plane<T: Pixel>(ra: &mut ChaChaRng, plane: &mut Plane<T>) {
-  let stride = plane.cfg.stride;
-  for row in plane.data_origin_mut().chunks_mut(stride) {
-    for pixel in row {
-      let v: u8 = ra.gen();
-      *pixel = T::cast_from(v);
-    }
-  }
-}
-
-fn new_plane<T: Pixel>(
-  ra: &mut ChaChaRng, width: usize, height: usize,
-) -> Plane<T> {
-  let mut p = Plane::new(width, height, 0, 0, 128 + 8, 128 + 8);
-
-  fill_plane(ra, &mut p);
-
-  p
-}
-
 fn get_params<T: Pixel>(
   rec_plane: &Plane<T>, po: PlaneOffset, mv: MotionVector,
 ) -> (i32, i32, PlaneSlice<T>) {
diff --git a/clippy.toml b/clippy.toml
index 2ef5862dea..bb141ab8e0 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -1,4 +1,5 @@
 too-many-arguments-threshold = 16
 cognitive-complexity-threshold = 40
-trivial-copy-size-limit = 16 # 128-bits = 2 64-bit registers
+trivial-copy-size-limit = 16 # 128-bits = 2 64-bit registers
+doc-valid-idents = ["DFTTest", "DFTTest2"]
 msrv = "1.60"
diff --git a/src/api/config/encoder.rs b/src/api/config/encoder.rs
index 7f84d5a081..c36bbda0f7 100644
--- a/src/api/config/encoder.rs
+++ b/src/api/config/encoder.rs
@@ -85,6 +85,8 @@ pub struct EncoderConfig {
   pub tune: Tune,
   /// Parameters for grain synthesis.
   pub film_grain_params: Option<Vec<GrainTableSegment>>,
+  /// Strength of denoising, 0 = disabled
+  pub denoise_strength: u8,
   /// Number of tiles horizontally. Must be a power of two.
   ///
   /// Overridden by [`tiles`], if present.
@@ -159,6 +161,7 @@ impl EncoderConfig {
       bitrate: 0,
       tune: Tune::default(),
       film_grain_params: None,
+      denoise_strength: 0,
       tile_cols: 0,
       tile_rows: 0,
       tiles: 0,
diff --git a/src/api/internal.rs b/src/api/internal.rs
index 7ba9da0748..bac4910280 100644
--- a/src/api/internal.rs
+++ b/src/api/internal.rs
@@ -15,6 +15,7 @@ use crate::api::{
 };
 use crate::color::ChromaSampling::Cs400;
 use crate::cpu_features::CpuFeatureLevel;
+use crate::denoise::{DftDenoiser, TEMPORAL_RADIUS};
 use crate::dist::get_satd;
 use crate::encoder::*;
 use crate::frame::*;
@@ -220,7 +221,7 @@ impl<T: Pixel> FrameData<T> {
   }
 }
 
-type FrameQueue<T> = BTreeMap<u64, Option<Arc<Frame<T>>>>;
+pub(crate) type FrameQueue<T> = BTreeMap<u64, Option<Arc<Frame<T>>>>;
 type FrameDataQueue<T> = BTreeMap<u64, Option<FrameData<T>>>;
 
 // the fields pub(super) are accessed only by the tests
@@ -248,6 +249,7 @@ pub(crate) struct ContextInner<T: Pixel> {
   /// Maps `output_frameno` to `gop_input_frameno_start`.
   pub(crate) gop_input_frameno_start: BTreeMap<u64, u64>,
   keyframe_detector: SceneChangeDetector<T>,
+  denoiser: Option<DftDenoiser<T>>,
   pub(crate) config: Arc<EncoderConfig>,
   seq: Arc<Sequence>,
   pub(crate) rc_state: RCState,
@@ -295,6 +297,16 @@ impl<T: Pixel> ContextInner<T> {
         lookahead_distance,
         seq.clone(),
       ),
+      denoiser: if enc.denoise_strength > 0 {
+        Some(DftDenoiser::<T>::new(
+          enc.denoise_strength as f32 / 10.0,
+          enc.width,
+          enc.height,
+          enc.bit_depth,
+        ))
+      } else {
+        None
+      },
       config: Arc::new(enc.clone()),
       seq,
       rc_state: RCState::new(
@@ -359,6 +371,25 @@ impl<T: Pixel> ContextInner<T> {
       self.t35_q.insert(input_frameno, params.t35_metadata);
     }
 
+    // If denoising is enabled, run it now because we want the entire
+    // encoding process, including lookahead, to see the denoised frame.
+    if let Some(ref mut denoiser) = self.denoiser {
+      loop {
+        let denoiser_frame = denoiser.cur_frameno;
+        if (!is_flushing
+          && input_frameno >= denoiser_frame + TEMPORAL_RADIUS as u64)
+          || (is_flushing && Some(denoiser_frame) < self.limit)
+        {
+          self.frame_q.insert(
+            denoiser_frame,
+            Some(Arc::new(denoiser.filter_frame(&self.frame_q).unwrap())),
+          );
+        } else {
+          break;
+        }
+      }
+    }
+
     if !self.needs_more_frame_q_lookahead(self.next_lookahead_frame) {
       let lookahead_frames = self
         .frame_q
diff --git a/src/api/test.rs b/src/api/test.rs
index 072562631f..3d167fdfc5 100644
--- a/src/api/test.rs
+++ b/src/api/test.rs
@@ -2131,6 +2131,7 @@ fn log_q_exp_overflow() {
     tile_cols: 0,
     tile_rows: 0,
     tiles: 0,
+    denoise_strength: 0,
     speed_settings: SpeedSettings {
       multiref: false,
       fast_deblock: true,
@@ -2207,6 +2208,7 @@ fn guess_frame_subtypes_assert() {
     tile_cols: 0,
     tile_rows: 0,
     tiles: 0,
+    denoise_strength: 0,
     speed_settings: SpeedSettings {
       multiref: false,
       fast_deblock: true,
diff --git a/src/asm/x86/transform/forward.rs b/src/asm/x86/transform/forward.rs
index 18b1171517..7a46c310bb 100644
--- a/src/asm/x86/transform/forward.rs
+++ b/src/asm/x86/transform/forward.rs
@@ -316,20 +316,6 @@ impl SizeClass1D {
   }
 }
 
-fn cast<const N: usize, T>(x: &[T]) -> &[T; N] {
-  // SAFETY: we perform a bounds check with [..N],
-  // so casting to *const [T; N] is valid because the bounds
-  // check guarantees that x has N elements
-  unsafe { &*(&x[..N] as *const [T] as *const [T; N]) }
-}
-
-fn cast_mut<const N: usize, T>(x: &mut [T]) -> &mut [T; N] {
-  // SAFETY: we perform a bounds check with [..N],
-  // so casting to *mut [T; N] is valid because the bounds
-  // check guarantees that x has N elements
-  unsafe { &mut *(&mut x[..N] as *mut [T] as *mut [T; N]) }
-}
-
 #[allow(clippy::identity_op, clippy::erasing_op)]
 #[target_feature(enable = "avx2")]
 unsafe fn forward_transform_avx2(
diff --git a/src/bin/common.rs b/src/bin/common.rs
index 3e0d5aa8c4..7f8edf6134 100644
--- 
a/src/bin/common.rs +++ b/src/bin/common.rs @@ -183,6 +183,17 @@ pub struct CliOptions { help_heading = "ENCODE SETTINGS" )] pub photon_noise: u8, + /// Enable spatio-temporal denoising, intended to be used with grain synthesis. + /// Takes a strength value 0-50. + /// + /// Default strength is 1/2 of photon noise strength, + /// or 4 if a photon noise table is specified. + #[clap( + long, + value_parser = clap::value_parser!(u8).range(0..=50), + help_heading = "ENCODE SETTINGS" + )] + pub denoise: Option, /// Uses a film grain table file to apply grain synthesis to the encode. /// Uses the same table file format as aomenc and svt-av1. #[clap( @@ -661,7 +672,19 @@ fn parse_config(matches: &CliOptions) -> Result { .expect("Failed to parse film grain table"); if !table.is_empty() { cfg.film_grain_params = Some(table); + cfg.denoise_strength = 4; + } + } else if matches.photon_noise > 0 { + cfg.denoise_strength = matches.photon_noise / 2; + // We have to know the video resolution before we can generate a table, + // so we must handle that elsewhere. + } + // A user set denoise strength overrides the defaults above + if let Some(denoise_str) = matches.denoise { + if denoise_str > 50 { + panic!("Denoising strength must be between 0-50"); } + cfg.denoise_strength = denoise_str; } if let Some(frame_rate) = matches.frame_rate { diff --git a/src/denoise/blend.rs b/src/denoise/blend.rs new file mode 100644 index 0000000000..5eaa122387 --- /dev/null +++ b/src/denoise/blend.rs @@ -0,0 +1,1308 @@ +//! A pox on the house of whoever decided this was a good way to code anything + +use std::mem::{size_of, transmute}; + +use arrayvec::ArrayVec; +use wide::{f32x8, f64x4, i32x8, i64x4}; + +use crate::util::cast; + +use super::{f32x16, f64x8}; + +const V_DC: i32 = -256; + +pub fn blend16< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, + const I8: usize, + const I9: usize, + const I10: usize, + const I11: usize, + const I12: usize, + const I13: usize, + const I14: usize, + const I15: usize, +>( + a: f32x16, b: f32x16, +) -> f32x16 { + let x0 = blend_half16::(a, b); + let x1 = blend_half16::(a, b); + [x0, x1] +} + +fn blend_half16< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, +>( + a: f32x16, b: f32x16, +) -> f32x8 { + const N: usize = 8; + let ind: [usize; N] = [I0, I1, I2, I3, I4, I5, I6, I7]; + + // lambda to find which of the four possible sources are used + fn list_sources(ind: &[usize; N]) -> ArrayVec { + let mut source_used = [false; 4]; + for i in 0..N { + let ix = ind[i]; + let src = ix / N; + source_used[src & 3] = true; + } + // return a list of sources used. 
+ let mut sources = ArrayVec::new(); + for i in 0..4 { + if source_used[i] { + sources.push(i); + } + } + sources + } + + let sources = list_sources(&ind); + if sources.is_empty() { + return f32x8::ZERO; + } + + // get indexes for the first one or two sources + let uindex = if sources.len() > 2 { 1 } else { 2 }; + let l = blend_half_indexes::<8>( + uindex, + sources.get(0).copied(), + sources.get(1).copied(), + &ind, + ); + let src0 = select_blend16(sources.get(0).copied(), a, b); + let src1 = select_blend16(sources.get(1).copied(), a, b); + let mut x0 = blend8_f32(cast(&l[..8]), src0, src1); + + // get last one or two sources + if sources.len() > 2 { + let m = blend_half_indexes::<8>( + 1, + sources.get(2).copied(), + sources.get(3).copied(), + &ind, + ); + let src2 = select_blend16(sources.get(2).copied(), a, b); + let src3 = select_blend16(sources.get(3).copied(), a, b); + let x1 = blend8_f32(cast(&m[..8]), src2, src3); + + // combine result of two blends. Unused elements are zero + x0 |= x1; + } + + x0 +} + +fn select_blend16(action: Option, a: f32x16, b: f32x16) -> f32x8 { + match action { + Some(0) => a[0], + Some(1) => a[1], + Some(2) => b[0], + _ => b[1], + } +} + +#[inline(always)] +fn blend8_f32(indexes: &[i32; 8], a: f32x8, b: f32x8) -> f32x8 { + let mut y = a; + let flags = blend_flags::<8, { size_of::() }>(&indexes); + + assert!(flags & BLEND_OUTOFRANGE == 0); + + if flags & BLEND_ALLZERO > 0 { + return f32x8::ZERO; + } + + if flags & BLEND_LARGEBLOCK > 0 { + // blend and permute 32-bit blocks + let l = largeblock_perm::<8, 4>(indexes); + // SAFETY: Types are of same size + y = unsafe { transmute(blend4_f64(&l, transmute(a), transmute(b))) }; + if flags & BLEND_ADDZ == 0 { + // no remaining zeroing + return y; + } + } else if flags & BLEND_B == 0 { + // nothing from b. just permute a + return permute8(indexes, a); + } else if flags & BLEND_A == 0 { + let l = blend_perm_indexes::<8, 16, 2>(indexes); + return permute8(cast(&l[8..]), b); + } else if flags & (BLEND_PERMA | BLEND_PERMB) == 0 { + // no permutation, only blending + let mb = make_bit_mask::<8, 0x303>(indexes) as u8; + y = f32x8::new([ + if mb & (1 << 0) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 1) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 2) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 3) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 4) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 5) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 6) > 0 { -1f32 } else { 0f32 }, + if mb & (1 << 7) > 0 { -1f32 } else { 0f32 }, + ]) + .blend(b, a); + } else if flags & BLEND_PUNPCKLAB > 0 { + y = a.interleave_low(b); + } else if flags & BLEND_PUNPCKLBA > 0 { + y = b.interleave_low(a); + } else if flags & BLEND_PUNPCKHAB > 0 { + y = a.interleave_high(b); + } else if flags & BLEND_PUNPCKHBA > 0 { + y = b.interleave_high(a); + } else if flags & BLEND_SHUFAB > 0 { + y = a.shuffle_nonconst(b, (flags >> BLEND_SHUFPATTERN) as u8); + } else if flags & BLEND_SHUFBA > 0 { + y = b.shuffle_nonconst(a, (flags >> BLEND_SHUFPATTERN) as u8); + } else { + // No special cases + // permute a and b separately, then blend. 
+ let l = blend_perm_indexes::<8, 16, 0>(indexes); + let ya = permute8(cast(&l[..8]), a); + let yb = permute8(cast(&l[8..]), b); + let mb = make_bit_mask::<8, 0x303>(indexes) as u8; + y = f32x8::new([ + if mb & 0b10000000 > 0 { -1f32 } else { 0f32 }, + if mb & 0b01000000 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00100000 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00010000 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00001000 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00000100 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00000010 > 0 { -1f32 } else { 0f32 }, + if mb & 0b00000001 > 0 { -1f32 } else { 0f32 }, + ]) + .blend(yb, ya); + } + if flags & BLEND_ZEROING > 0 { + // additional zeroing needed + let bm = i32x8::new(zero_mask_broad_8x32(indexes)); + // SAFETY: Types have the same size + let bm1: f32x8 = unsafe { transmute(bm) }; + y = bm1 & y; + } + + y +} + +pub fn blend8< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, + const I4: usize, + const I5: usize, + const I6: usize, + const I7: usize, +>( + a: f64x8, b: f64x8, +) -> f64x8 { + let x0 = blend_half8::(a, b); + let x1 = blend_half8::(a, b); + [x0, x1] +} + +fn blend_half8< + const I0: usize, + const I1: usize, + const I2: usize, + const I3: usize, +>( + a: f64x8, b: f64x8, +) -> f64x4 { + const N: usize = 4; + let ind: [usize; N] = [I0, I1, I2, I3]; + + // lambda to find which of the four possible sources are used + fn list_sources(ind: &[usize; N]) -> ArrayVec { + let mut source_used = [false; 4]; + for i in 0..N { + let ix = ind[i]; + let src = ix / N; + source_used[src & 3] = true; + } + // return a list of sources used. + let mut sources = ArrayVec::new(); + for i in 0..4 { + if source_used[i] { + sources.push(i); + } + } + sources + } + + let sources = list_sources(&ind); + if sources.is_empty() { + return f64x4::ZERO; + } + + // get indexes for the first one or two sources + let uindex = if sources.len() > 2 { 1 } else { 2 }; + let l = blend_half_indexes::<4>( + uindex, + sources.get(0).copied(), + sources.get(1).copied(), + &ind, + ); + let src0 = select_blend8(sources.get(0).copied(), a, b); + let src1 = select_blend8(sources.get(1).copied(), a, b); + let mut x0 = blend4_f64(cast(&l[..4]), src0, src1); + + // get last one or two sources + if sources.len() > 2 { + let m = blend_half_indexes::<4>( + 1, + sources.get(2).copied(), + sources.get(3).copied(), + &ind, + ); + let src2 = select_blend8(sources.get(2).copied(), a, b); + let src3 = select_blend8(sources.get(3).copied(), a, b); + let x1 = blend4_f64(cast(&m[..4]), src2, src3); + + // combine result of two blends. Unused elements are zero + x0 |= x1; + } + + x0 +} + +// blend_half_indexes: return an Indexlist for emulating a blend function as +// blends or permutations from multiple sources +// dozero = 0: let unused elements be don't care. Multiple permutation results must be blended +// dozero = 1: zero unused elements in each permuation. 
Multiple permutation results can be OR'ed +// dozero = 2: indexes that are -1 or V_DC are preserved +// src1, src2: sources to blend in a partial implementation +fn blend_half_indexes( + dozero: u8, src1: Option, src2: Option, ind: &[usize; N], +) -> ArrayVec { + // a is a reference to a constexpr array of permutation indexes + let mut list = ArrayVec::new(); + // value to use for unused entries + let u = if dozero > 0 { -1 } else { V_DC }; + + for j in 0..N { + let idx = ind[j]; + let src = idx / N; + list.push(if src1 == Some(src) { + (idx & (N - 1)) as i32 + } else if src2 == Some(src) { + ((idx & (N - 1)) + N) as i32 + } else { + u + }); + } + + list +} + +fn select_blend8(action: Option, a: f64x8, b: f64x8) -> f64x4 { + match action { + Some(0) => a[0], + Some(1) => a[1], + Some(2) => b[0], + _ => b[1], + } +} + +#[inline(always)] +fn blend4_f64(indexes: &[i32; 4], a: f64x4, b: f64x4) -> f64x4 { + let flags = blend_flags::<4, { size_of::() }>(indexes); + assert!( + flags & BLEND_OUTOFRANGE == 0, + "Index out of range in blend function" + ); + + if flags & BLEND_ALLZERO != 0 { + return f64x4::ZERO; + } + + if flags & BLEND_B == 0 { + return permute4(indexes, a); + } + if flags & BLEND_A == 0 { + return permute4( + &[ + if i[0] < 0 { i[0] } else { i[0] & 3 }, + if i[1] < 0 { i[1] } else { i[1] & 3 }, + if i[2] < 0 { i[2] } else { i[2] & 3 }, + if i[3] < 0 { i[3] } else { i[3] & 3 }, + ], + b, + ); + } + + let mut y = a; + if flags & (BLEND_PERMA | BLEND_PERMB) == 0 { + // no permutation, only blending + let mb = make_bit_mask::<4, 0x302>(indexes); + y = f64x4::new([ + if mb & (1 << 0) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 1) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 2) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 3) > 0 { -1f64 } else { 0f64 }, + ]) + .blend(b, a); + } else if flags & BLEND_LARGEBLOCK != 0 { + // blend and permute 128-bit blocks + let l = largeblock_perm::<4, 2>(indexes); + let pp = (l[0] & 0xF) as u8 | ((l[1] & 0xF) as u8) << 4; + // y = _mm256_permute2f128_pd(a, b, pp); + todo!(); + } else if flags & BLEND_PUNPCKLAB != 0 { + y = a.interleave_low(b); + } else if flags & BLEND_PUNPCKLBA != 0 { + y = b.interleave_low(a); + } else if flags & BLEND_PUNPCKHAB != 0 { + y = a.interleave_high(b); + } else if flags & BLEND_PUNPCKHBA != 0 { + y = b.interleave_high(a); + } else if flags & BLEND_SHUFAB != 0 { + y = a.shuffle_nonconst(b, ((flags >> BLEND_SHUFPATTERN) as u8) & 0xF); + } else if flags & BLEND_SHUFBA != 0 { + y = b.shuffle_nonconst(a, ((flags >> BLEND_SHUFPATTERN) as u8) & 0xF); + } else { + // No special cases + let l = blend_perm_indexes::<4, 8, 0>(indexes); + let ya = permute4(cast(&l[..4]), a); + let yb = permute4(cast(&l[4..]), b); + let mb = make_bit_mask::<4, 0x302>(indexes); + y = f64x4::new([ + if mb & (1 << 0) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 1) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 2) > 0 { -1f64 } else { 0f64 }, + if mb & (1 << 3) > 0 { -1f64 } else { 0f64 }, + ]) + .blend(yb, ya); + } + + if flags & BLEND_ZEROING != 0 { + // additional zeroing needed + let bm = i64x4::new(zero_mask_broad_4x64(indexes)); + // SAFETY: Types have the same size + let bm1: f64x4 = unsafe { transmute(bm) }; + y = bm1 & y; + } + + y +} + +// blend_flags: returns information about how a blend function can be implemented +// The return value is composed of these flag bits: + +// needs zeroing +const BLEND_ZEROING: u64 = 1; +// all is zero or don't care +const BLEND_ALLZERO: u64 = 2; +// fits blend with a larger block size (e.g permute Vec2q instead of 
Vec4i) +const BLEND_LARGEBLOCK: u64 = 4; +// additional zeroing needed after blend with larger block size or shift +const BLEND_ADDZ: u64 = 8; +// has data from a +const BLEND_A: u64 = 0x10; +// has data from b +const BLEND_B: u64 = 0x20; +// permutation of a needed +const BLEND_PERMA: u64 = 0x40; +// permutation of b needed +const BLEND_PERMB: u64 = 0x80; +// permutation crossing 128-bit lanes +const BLEND_CROSS_LANE: u64 = 0x100; +// same permute/blend pattern in all 128-bit lanes +const BLEND_SAME_PATTERN: u64 = 0x200; +// pattern fits punpckh(a,b) +const BLEND_PUNPCKHAB: u64 = 0x1000; +// pattern fits punpckh(b,a) +const BLEND_PUNPCKHBA: u64 = 0x2000; +// pattern fits punpckl(a,b) +const BLEND_PUNPCKLAB: u64 = 0x4000; +// pattern fits punpckl(b,a) +const BLEND_PUNPCKLBA: u64 = 0x8000; +// pattern fits palignr(a,b) +const BLEND_ROTATEAB: u64 = 0x10000; +// pattern fits palignr(b,a) +const BLEND_ROTATEBA: u64 = 0x20000; +// pattern fits shufps/shufpd(a,b) +const BLEND_SHUFAB: u64 = 0x40000; +// pattern fits shufps/shufpd(b,a) +const BLEND_SHUFBA: u64 = 0x80000; +// pattern fits rotation across lanes. count returned in bits blend_rotpattern +const BLEND_ROTATE_BIG: u64 = 0x100000; +// index out of range +const BLEND_OUTOFRANGE: u64 = 0x10000000; +// pattern for shufps/shufpd is in bit blend_shufpattern to blend_shufpattern + 7 +const BLEND_SHUFPATTERN: u64 = 32; +// pattern for palignr is in bit blend_rotpattern to blend_rotpattern + 7 +const BLEND_ROTPATTERN: u64 = 40; + +// INSTRSET = 8 +fn blend_flags(a: &[i32; N]) -> u64 { + let mut ret = BLEND_LARGEBLOCK | BLEND_SAME_PATTERN | BLEND_ALLZERO; + // number of 128-bit lanes + let n_lanes = V / 16; + // elements per lane + let lane_size = N / n_lanes; + // current lane + let mut lane = 0; + // rotate left count + let mut rot: u32 = 999; + // GeNERIc pARAMeTerS canNoT BE useD In consT ConTEXTs + let mut lane_pattern = vec![0; lane_size].into_boxed_slice(); + if lane_size == 2 && N <= 8 { + ret |= BLEND_SHUFAB | BLEND_SHUFBA; + } + + for ii in 0..N { + let ix = a[ii]; + if ix < 0 { + if ix == -1 { + ret |= BLEND_ZEROING; + } else if ix != V_DC { + ret = BLEND_OUTOFRANGE; + break; + } + } else { + ret &= !BLEND_ALLZERO; + if ix < N as i32 { + ret |= BLEND_A; + if ix != ii as i32 { + ret |= BLEND_PERMA; + } + } else if ix < 2 * N as i32 { + ret |= BLEND_B; + if ix != (ii + N) as i32 { + ret |= BLEND_PERMB; + } + } else { + ret = BLEND_OUTOFRANGE; + break; + } + } + + // check if pattern fits a larger block size: + // even indexes must be even, odd indexes must fit the preceding even index + 1 + if (ii & 1) == 0 { + if ix >= 0 && (ix & 1) > 0 { + ret &= !BLEND_LARGEBLOCK; + } + let iy = a[ii + 1]; + if iy >= 0 && (iy & 1) == 0 { + ret &= !BLEND_LARGEBLOCK; + } + if ix >= 0 && iy >= 0 && iy != ix + 1 { + ret &= !BLEND_LARGEBLOCK; + } + if ix == -1 && iy >= 0 { + ret |= BLEND_ADDZ; + } + if iy == -1 && ix >= 0 { + ret |= BLEND_ADDZ; + } + } + + lane = ii / lane_size; + if lane == 0 { + lane_pattern[ii] = ix; + } + + // check if crossing lanes + if ix >= 0 { + let lane_i = (ix & !(N as i32)) as usize / lane_size; + if lane_i != lane { + ret |= BLEND_CROSS_LANE; + } + if lane_size == 2 { + // check if it fits pshufd + if lane_i != lane { + ret &= !(BLEND_SHUFAB | BLEND_SHUFBA); + } + if (((ix & (N as i32)) != 0) as usize ^ ii) & 1 > 0 { + ret &= !BLEND_SHUFAB; + } else { + ret &= !BLEND_SHUFBA; + } + } + } + + // check if same pattern in all lanes + if lane != 0 && ix >= 0 { + let j = ii - (lane * lane_size); + let jx = ix - (lane * lane_size) as 
i32; + if jx < 0 || (jx & !(N as i32)) >= lane_size as i32 { + ret &= !BLEND_SAME_PATTERN; + } + if lane_pattern[j] < 0 { + lane_pattern[j] = jx; + } else if lane_pattern[j] != jx { + ret &= !BLEND_SAME_PATTERN; + } + } + } + + if ret & BLEND_LARGEBLOCK == 0 { + ret &= !BLEND_ADDZ; + } + if ret & BLEND_CROSS_LANE > 0 { + ret &= !BLEND_SAME_PATTERN; + } + if ret & (BLEND_PERMA | BLEND_PERMB) == 0 { + return ret; + } + + if ret & BLEND_SAME_PATTERN > 0 { + // same pattern in all lanes. check if it fits unpack patterns + ret |= + BLEND_PUNPCKHAB | BLEND_PUNPCKHBA | BLEND_PUNPCKLAB | BLEND_PUNPCKLBA; + for iu in 0..(lane_size as u32) { + let ix = lane_pattern[iu as usize]; + if ix >= 0 { + let ix = ix as u32; + if ix != iu / 2 + (iu & 1) * N as u32 { + ret &= !BLEND_PUNPCKLAB; + } + if ix != iu / 2 + ((iu & 1) ^ 1) * N as u32 { + ret &= !BLEND_PUNPCKLBA; + } + if ix != (iu + lane_size as u32) / 2 + (iu & 1) * N as u32 { + ret &= !BLEND_PUNPCKHAB; + } + if ix != (iu + lane_size as u32) / 2 + ((iu & 1) ^ 1) * N as u32 { + ret &= !BLEND_PUNPCKHBA; + } + } + } + + for iu in 0..(lane_size as u32) { + // check if it fits palignr + let ix = lane_pattern[iu as usize]; + if ix >= 0 { + let ix = ix as u32; + let mut t = ix & !(N as u32); + if (ix & N as u32) > 0 { + t += lane_size as u32; + } + let tb = (t + 2 * lane_size as u32 - iu) % (lane_size as u32 * 2); + if rot == 999 { + rot = tb; + } else if rot != tb { + rot = 1000; + } + } + } + if rot < 999 { + // fits palignr + if rot < lane_size as u32 { + ret |= BLEND_ROTATEBA; + } else { + ret |= BLEND_ROTATEAB; + } + let elem_size = (V / N) as u32; + ret |= (((rot & (lane_size as u32 - 1)) * elem_size) as u64) + << BLEND_ROTPATTERN; + } + + if lane_size == 4 { + // check if it fits shufps + ret |= BLEND_SHUFAB | BLEND_SHUFBA; + for ii in 0..2 { + let ix = lane_pattern[ii]; + if ix >= 0 { + if ix & N as i32 > 0 { + ret &= !BLEND_SHUFAB; + } else { + ret &= !BLEND_SHUFBA; + } + } + } + for ii in 2..4 { + let ix = lane_pattern[ii]; + if ix >= 0 { + if ix & N as i32 > 0 { + ret &= !BLEND_SHUFBA; + } else { + ret &= !BLEND_SHUFAB; + } + } + } + if ret & (BLEND_SHUFAB | BLEND_SHUFBA) > 0 { + // fits shufps/shufpd + let mut shuf_pattern = 0u8; + for iu in 0..lane_size { + shuf_pattern |= ((lane_pattern[iu] & 3) as u8) << (iu * 2); + } + ret |= (shuf_pattern as u64) << BLEND_SHUFPATTERN; + } + } + } else if n_lanes > 1 { + // not same pattern in all lanes + let mut rot = 999; + for ii in 0..N { + let ix = a[ii]; + if ix >= 0 { + let rot2: u32 = + (ix + 2 * N as i32 - ii as i32) as u32 % (2 * N) as u32; + if rot == 999 { + rot = rot2; + } else if rot != rot2 { + rot = 1000; + break; + } + } + } + if rot < 2 * N as u32 { + // fits big rotate + ret |= BLEND_ROTATE_BIG | (rot as u64) << BLEND_ROTPATTERN; + } + } + if lane_size == 2 && (ret & (BLEND_SHUFAB | BLEND_SHUFBA)) > 0 { + for ii in 0..N { + ret |= ((a[ii] & 1) as u64) << (BLEND_SHUFPATTERN + ii as u64); + } + } + + ret +} + +// largeblock_perm: return indexes for replacing a permute or blend with +// a certain block size by a permute or blend with the double block size. 
+// Note: it is presupposed that perm_flags() indicates perm_largeblock +// It is required that additional zeroing is added if perm_flags() indicates perm_addz +fn largeblock_perm( + a: &[i32; N], +) -> [i32; N2] { + // GeNERIc pARAMeTerS canNoT BE useD In consT ConTEXTs + assert!(N2 == N / 2); + + // Parameter a is a reference to a constexpr array of permutation indexes + let mut list = [0; N2]; + let mut fit_addz = false; + + // check if additional zeroing is needed at current block size + for i in (0..N).step_by(2) { + let ix = a[i]; + let iy = a[i + 1]; + if (ix == -1 && iy >= 0) || (iy == -1 && ix >= 0) { + fit_addz = true; + break; + } + } + + // loop through indexes + for i in (0..N).step_by(2) { + let ix = a[i]; + let iy = a[i + 1]; + let iz = if ix >= 0 { + ix / 2 + } else if iy >= 0 { + iy / 2 + } else if fit_addz { + V_DC + } else { + ix | iy + }; + list[i / 2] = iz; + } + + list +} + +// perm_flags: returns information about how a permute can be implemented. +// The return value is composed of these flag bits: + +// needs zeroing +const PERM_ZEROING: u64 = 1; +// permutation needed +const PERM_PERM: u64 = 2; +// all is zero or don't care +const PERM_ALLZERO: u64 = 4; +// fits permute with a larger block size (e.g permute Vec2q instead of Vec4i) +const PERM_LARGEBLOCK: u64 = 8; +// additional zeroing needed after permute with larger block size or shift +const PERM_ADDZ: u64 = 0x10; +// additional zeroing needed after perm_zext, perm_compress, or perm_expand +const PERM_ADDZ2: u64 = 0x20; +// permutation crossing 128-bit lanes +const PERM_CROSS_LANE: u64 = 0x40; +// same permute pattern in all 128-bit lanes +const PERM_SAME_PATTERN: u64 = 0x80; +// permutation pattern fits punpckh instruction +const PERM_PUNPCKH: u64 = 0x100; +// permutation pattern fits punpckl instruction +const PERM_PUNPCKL: u64 = 0x200; +// permutation pattern fits rotation within lanes. 4 bit count returned in bit perm_rot_count +const PERM_ROTATE: u64 = 0x400; +// permutation pattern fits shift right within lanes. 4 bit count returned in bit perm_rot_count +const PERM_SHRIGHT: u64 = 0x1000; +// permutation pattern fits shift left within lanes. negative count returned in bit perm_rot_count +const PERM_SHLEFT: u64 = 0x2000; +// permutation pattern fits rotation across lanes. 6 bit count returned in bit perm_rot_count +const PERM_ROTATE_BIG: u64 = 0x4000; +// permutation pattern fits broadcast of a single element. 
+const PERM_BROADCAST: u64 = 0x8000; +// permutation pattern fits zero extension +const PERM_ZEXT: u64 = 0x10000; +// permutation pattern fits vpcompress instruction +const PERM_COMPRESS: u64 = 0x20000; +// permutation pattern fits vpexpand instruction +const PERM_EXPAND: u64 = 0x40000; +// index out of range +const PERM_OUTOFRANGE: u64 = 0x10000000; +// rotate or shift count is in bits perm_rot_count to perm_rot_count+3 +const PERM_ROT_COUNT: u64 = 32; +// pattern for pshufd is in bit perm_ipattern to perm_ipattern + 7 if perm_same_pattern and elementsize >= 4 +const PERM_IPATTERN: u64 = 40; + +fn permute8(indexes: &[i32; 8], a: f32x8) -> f32x8 { + let mut y = a; + let flags = perm_flags::<32, 8>(indexes); + assert!( + flags & PERM_OUTOFRANGE == 0, + "Index out of range in permute function" + ); + + if flags & PERM_ALLZERO > 0 { + return f32x8::ZERO; + } + + if flags & PERM_PERM > 0 { + if flags & PERM_LARGEBLOCK > 0 { + let l = largeblock_perm::<8, 4>(indexes); + // SAFETY: Types are of same size + y = unsafe { transmute(permute4(&l, transmute(a))) }; + if flags & PERM_ADDZ == 0 { + // no remaining zeroing + return y; + } + } else if flags & PERM_SAME_PATTERN > 0 { + if flags & PERM_PUNPCKH != 0 { + // fits punpckhi + y = y.interleave_high(y); + } else if flags & PERM_PUNPCKL != 0 { + // fits punpcklo + y = y.interleave_low(y); + } else { + // general permute, same pattern in both lanes + y = a.shuffle_nonconst(a, (flags >> PERM_IPATTERN) as u8); + } + } else if flags & PERM_BROADCAST > 0 && flags >> PERM_ROT_COUNT == 0 { + // broadcast first element + y = f32x8::from(y.to_array()[0]); + } else if flags & PERM_ZEXT > 0 { + // zero extension + // y = _mm256_castsi256_ps(_mm256_cvtepu32_epi64( + // _mm256_castsi256_si128(_mm256_castps_si256(y)))); + todo!(); + if flags & PERM_ADDZ2 == 0 { + return y; + } + } else if flags & PERM_CROSS_LANE == 0 { + // __m256 m = constant8f(); + // y = _mm256_permutevar_ps(a, _mm256_castps_si256(m)); + todo!(); + } else { + // full permute needed + // __m256i permmask = + // _mm256_castps_si256(constant8f()); + // y = _mm256_permutevar8x32_ps(a, permmask); + todo!(); + } + } + + if flags & PERM_ZEROING > 0 { + let bm = i32x8::new(zero_mask_broad_8x32(indexes)); + // SAFETY: Types have the same size + let bm1: f32x8 = unsafe { transmute(bm) }; + y = bm1 & y; + } + + y +} + +fn perm_flags( + a: &[i32; ELEMS], +) -> u64 { + // number of 128-bit lanes + let num_lanes = ELEM_SIZE * ELEMS / 16; + let lane_size = ELEMS / num_lanes; + // current lane + let mut lane = 0usize; + // rotate left count + let mut rot = 999u32; + // index to broadcasted element + let mut broadc = 999i32; + // remember certain patterns that do not fit + let mut patfail = 0u32; + // remember certain patterns need extra zeroing + let mut addz2 = 0u32; + // last index in perm_compress fit + let mut compresslasti = -1i32; + // last position in perm_compress fit + let mut compresslastp = -1i32; + // last index in perm_expand fit + let mut expandlasti = -1i32; + // last position in perm_expand fit + let mut expandlastp = -1i32; + + let mut ret = PERM_LARGEBLOCK | PERM_SAME_PATTERN | PERM_ALLZERO; + let mut lane_pattern = vec![0i32; lane_size].into_boxed_slice(); + for i in 0..ELEMS { + let ix = a[i]; + // meaning of ix: -1 = set to zero, V_DC = don't care, non-negative value = permute. 
+ if ix == -1 { + ret |= PERM_ZEROING; + } else if ix != V_DC && ix as usize >= ELEMS { + ret |= PERM_OUTOFRANGE; + } + + if ix >= 0 { + ret &= !PERM_ALLZERO; + if ix != i as i32 { + ret |= PERM_PERM; + } + if broadc == 999 { + // remember broadcast index + broadc = ix; + } else if broadc != ix { + // does not fit broadcast + broadc = 1000; + } + } + + // check if pattern fits a larger block size: + // even indexes must be even, odd indexes must fit the preceding even index + 1 + if i & 1 == 0 { + if ix > 0 && ix & 1 > 0 { + ret &= !PERM_LARGEBLOCK; + } + let iy = a[i + 1]; + if iy >= 0 && iy & 1 == 0 { + ret &= !PERM_LARGEBLOCK; + } + if ix >= 0 && iy >= 0 && iy != ix + 1 { + ret &= !PERM_LARGEBLOCK; + } + if ix == -1 && iy >= 0 { + ret |= PERM_ADDZ; + } + if iy == -1 && ix >= 0 { + ret |= PERM_ADDZ; + } + } + + lane = i / lane_size; + if lane == 0 { + // first lane, or no pattern yet + lane_pattern[i] = ix; + } + + // check if crossing lanes + if ix >= 0 { + let lanei = ix as usize / lane_size; + if lanei != lane { + ret |= PERM_CROSS_LANE; + } + } + + // check if same pattern in all lanes + if lane != 0 && ix >= 0 { + // not first lane + let j1 = i - lane * lane_size; + let jx = ix - (lane * lane_size) as i32; + if jx < 0 || jx >= lane_size as i32 { + // source is in another lane + ret &= !PERM_SAME_PATTERN; + } + if lane_pattern[j1] < 0 { + // pattern not known from previous lane + lane_pattern[j1] = jx; + } else if lane_pattern[j1] != jx { + // not same pattern + ret &= !PERM_SAME_PATTERN; + } + } + + if ix >= 0 { + // check if pattern fits zero extension (perm_zext) + if (ix * 2) as usize != i { + // does not fit zero extension + patfail |= 1; + } + // check if pattern fits compress (perm_compress) + if ix > compresslasti && ix - compresslasti >= i as i32 - compresslastp { + if i as i32 - compresslastp > 1 { + // perm_compress may need additional zeroing + addz2 |= 2; + } + compresslasti = ix; + compresslastp = i as i32; + } else { + // does not fit perm_compress + patfail |= 2; + } + // check if pattern fits expand (perm_expand) + if ix > expandlasti && ix - expandlasti <= i as i32 - expandlastp { + if ix - expandlasti > 1 { + // perm_expand may need additional zeroing + addz2 |= 4; + } + expandlasti = ix; + expandlastp = i as i32; + } else { + // does not fit perm_compress + patfail |= 4; + } + } else if ix == -1 && i & 1 == 0 { + // zero extension needs additional zeroing + addz2 |= 1; + } + } + + if ret & PERM_PERM == 0 { + return ret; + } + + if ret & PERM_LARGEBLOCK == 0 { + ret &= !PERM_ADDZ; + } + if ret & PERM_CROSS_LANE > 0 { + ret &= !PERM_SAME_PATTERN; + } + if patfail & 1 == 0 { + ret |= PERM_ZEXT; + if addz2 & 1 != 0 { + ret |= PERM_ADDZ2; + } + } else if patfail & 2 == 0 { + ret |= PERM_COMPRESS; + if addz2 & 2 != 0 && compresslastp > 0 { + for j in 0..(compresslastp as usize) { + if a[j] == -1 { + ret |= PERM_ADDZ2; + } + } + } + } else if patfail & 4 == 0 { + ret |= PERM_EXPAND; + if addz2 & 4 != 0 && expandlastp > 0 { + for j in 0..(expandlastp as usize) { + if a[j] == -1 { + ret |= PERM_ADDZ2; + } + } + } + } + + if ret & PERM_SAME_PATTERN > 0 { + // same pattern in all lanes. 
check if it fits specific patterns + let mut fit = true; + // fit shift or rotate + for i in 0..lane_size { + if lane_pattern[i] >= 0 { + let rot1 = lane_pattern[i] as u32 + lane_size as u32 + - i as u32 % lane_size as u32; + if rot == 999 { + rot = rot1; + } else if rot != rot1 { + fit = false; + } + } + } + rot &= lane_size as u32 - 1; + if fit { + // fits rotate, and possible shift + let rot2 = ((rot & ELEM_SIZE as u32) & 0xF) as u64; + ret |= rot2 << PERM_ROT_COUNT; + ret |= PERM_ROTATE; + // fit shift left + fit = true; + let mut i = 0; + while (i + rot as usize) < lane_size { + // check if first rot elements are zero or don't care + if lane_pattern[i] >= 0 { + fit = false; + } + i += 1; + } + if fit { + ret |= PERM_SHLEFT; + while i < lane_size { + if lane_pattern[i] == -1 { + ret |= PERM_ADDZ; + } + i += 1; + } + } + // fit shift right + fit = true; + i = lane_size - rot as usize; + while i < lane_size { + // check if last (lanesize-rot) elements are zero or don't care + if lane_pattern[i] >= 0 { + fit = false; + } + i += 1; + } + if fit { + ret |= PERM_SHRIGHT; + while i < lane_size - rot as usize { + if lane_pattern[i] == -1 { + ret |= PERM_ADDZ; + } + i += 1; + } + } + } + + // fit punpckhi + fit = true; + let mut j2 = lane_size / 2; + for i in 0..lane_size { + if lane_pattern[i] >= 0 && lane_pattern[i] != j2 as i32 { + fit = false; + } + if (i & 1) != 0 { + j2 += 1; + } + } + if fit { + ret |= PERM_PUNPCKH; + } + // fit punpcklo + fit = true; + j2 = 0; + for i in 0..lane_size { + if lane_pattern[i] >= 0 && lane_pattern[i] != j2 as i32 { + fit = false; + } + if (i & 1) != 0 { + j2 += 1; + } + } + if fit { + ret |= PERM_PUNPCKL; + } + // fit pshufd + if ELEM_SIZE >= 4 { + let mut p = 0u64; + for i in 0..lane_size { + if lane_size == 4 { + p |= ((lane_pattern[i] & 3) as u64) << (2 * i as u64); + } else { + // lanesize = 2 + p |= ((lane_pattern[i] & 1) as u64 * 10 + 4) << (4 * i as u64); + } + } + ret |= p << PERM_IPATTERN; + } + } else { + // not same pattern in all lanes + if num_lanes > 1 { + // Try if it fits big rotate + for i in 0..ELEMS { + let ix = a[i]; + if ix >= 0 { + // rotate count + let mut rot2: u32 = + (ix as u32 + ELEMS as u32 - i as u32) % ELEMS as u32; + if rot == 999 { + // save rotate count + rot = rot2; + } else if rot != rot2 { + // does not fit big rotate + rot = 1000; + break; + } + } + } + if rot < ELEMS as u32 { + // fits big rotate + ret |= PERM_ROTATE_BIG | (rot as u64) << PERM_ROT_COUNT; + } + } + } + + if broadc < 999 + && ret & (PERM_ROTATE | PERM_SHRIGHT | PERM_SHLEFT | PERM_ROTATE_BIG) == 0 + { + // fits broadcast + ret |= PERM_BROADCAST | (broadc as u64) << PERM_ROT_COUNT; + } + + ret +} + +fn make_bit_mask(a: &[i32; LEN]) -> u64 { + let mut ret = 0u64; + let sel_bit_idx = (BITS & 0xFF) as u8; + let mut ret_bit_idx = 0u64; + let mut flip = 0u64; + for i in 0..LEN { + let ix = a[i]; + if ix < 0 { + ret_bit_idx = ((BITS >> 10) & 1) as u64; + } else { + ret_bit_idx = (((ix as u32) >> sel_bit_idx) & 1) as u64; + if i < LEN / 2 { + flip = ((BITS >> 8) & 1) as u64; + } else { + flip = ((BITS >> 9) & 1) as u64; + } + ret_bit_idx ^= flip ^ 1; + } + ret |= ret_bit_idx << i; + } + ret +} + +// blend_perm_indexes: return an Indexlist for implementing a blend function as +// two permutations. N = vector size. +// dozero = 0: let unused elements be don't care. The two permutation results must be blended +// dozero = 1: zero unused elements in each permuation. 
The two permutation results can be OR'ed +// dozero = 2: indexes that are -1 or V_DC are preserved +fn blend_perm_indexes( + a: &[i32; N], +) -> [i32; N2] { + assert!(N * 2 == N2); + assert!(DO_ZERO <= 2); + let mut list = [0i32; N2]; + let u = if DO_ZERO > 0 { -1 } else { V_DC }; + for j in 0..N { + let ix = a[j]; + if ix < 0 { + if DO_ZERO == 2 { + list[j] = ix; + list[j + N] = ix; + } else { + list[j] = u; + list[j + N] = u; + } + } else if ix < N as i32 { + list[j] = ix; + list[j + N] = u; + } else { + list[j] = u; + list[j + N] = ix - N as i32; + } + } + list +} + +#[inline(always)] +fn zero_mask_broad_8x32(a: &[i32; 8]) -> [i32; 8] { + let mut u = [0i32; 8]; + for i in 0..8 { + u[i] = if a[i] >= 0 { -1 } else { 0 }; + } + u +} + +#[inline(always)] +fn zero_mask_broad_4x64(a: &[i32; 4]) -> [i64; 4] { + let mut u = [0i64; 4]; + for i in 0..4 { + u[i] = if a[i] >= 0 { -1 } else { 0 }; + } + u +} + +fn permute4(indexes: &[i32; 4], a: f64x4) -> f64x4 { + let mut y = a; + let flags = perm_flags::<64, 4>(indexes); + + assert!( + flags & PERM_OUTOFRANGE == 0, + "Index out of range in permute function" + ); + + if flags & PERM_ALLZERO != 0 { + return f64x4::ZERO; + } + + if flags & PERM_LARGEBLOCK != 0 { + // permute 128-bit blocks + let l = largeblock_perm::<4, 2>(indexes); + let j0 = l[0]; + let j1 = l[1]; + + if j0 == 0 && j1 == -1 && flags & PERM_ADDZ == 0 { + // zero extend + // return _mm256_zextpd128_pd256(_mm256_castpd256_pd128(y)); + todo!(); + } + if j0 == 1 && j1 < 0 && flags & PERM_ADDZ == 0 { + // extract upper part, zero extend + // return _mm256_zextpd128_pd256(_mm256_extractf128_pd(y, 1)); + todo!(); + } + if flags & PERM_PERM != 0 && flags & PERM_ZEROING == 0 { + // return _mm256_permute2f128_pd(y, y, (j0 & 1) | (j1 & 1) << 4); + todo!(); + } + } + + if flags & PERM_PERM != 0 { + // permutation needed + if flags & PERM_SAME_PATTERN != 0 { + // same pattern in both lanes + if flags & PERM_PUNPCKH != 0 { + y = y.interleave_high(y); + } else if flags & PERM_PUNPCKL != 0 { + y = y.interleave_low(y); + } else { + // general permute + let mm0 = (indexes[0] & 1) + | (indexes[1] & 1) << 1 + | (indexes[2] & 1) << 2 + | (indexes[3] & 1) << 3; + // select within same lane + // y = _mm256_permute_pd(a, mm0); + todo!(); + } + } else if flags & PERM_BROADCAST != 0 && (flags >> PERM_ROT_COUNT) == 0 { + // broadcast first element + // y = _mm256_broadcastsd_pd( + // _mm256_castpd256_pd128(y)); + todo!(); + } else { + // different patterns in two lanes + if flags & PERM_CROSS_LANE == 0 { + // no lane crossing + // constexpr uint8_t mm0 = + // (i0 & 1) | (i1 & 1) << 1 | (i2 & 1) << 2 | (i3 & 1) << 3; + // y = _mm256_permute_pd(a, mm0); // select within same lane + todo!(); + } else { + // // full permute + // constexpr uint8_t mms = + // (i0 & 3) | (i1 & 3) << 2 | (i2 & 3) << 4 | (i3 & 3) << 6; + // y = _mm256_permute4x64_pd(a, mms); + todo!(); + } + } + } + + if flags & PERM_ZEROING != 0 { + // additional zeroing needed + // use broad mask + // constexpr EList bm = zero_mask_broad(indexs); + // // y = _mm256_and_pd(_mm256_castsi256_pd( Vec4q().load(bm.a) ), y); // does + // // not work with INSTRSET = 7 + // __m256i bm1 = _mm256_loadu_si256((const __m256i *)(bm.a)); + // y = _mm256_and_pd(_mm256_castsi256_pd(bm1), y); + todo!(); + } + y +} diff --git a/src/denoise/kernel.rs b/src/denoise/kernel.rs new file mode 100644 index 0000000000..443439bea1 --- /dev/null +++ b/src/denoise/kernel.rs @@ -0,0 +1,356 @@ +use std::{ + mem::{size_of, transmute}, + ptr::copy_nonoverlapping, +}; + +use 
super::blend::{blend16, blend8};
+use super::{
+  bitblt, calc_pad_size, f32x16, BLOCK_SIZE, TEMPORAL_RADIUS, TEMPORAL_SIZE,
+};
+use crate::util::cast;
+use av_metrics::video::Pixel;
+use num_traits::clamp;
+use wide::{f32x8, u16x8, u8x16};
+
+pub fn reflection_padding<T: Pixel>(
+  output: &mut [T], input: &[T], width: usize, height: usize, stride: usize,
+) {
+  let pad_width = calc_pad_size(width);
+  let pad_height = calc_pad_size(height);
+
+  let offset_y = (pad_height - height) / 2;
+  let offset_x = (pad_width - width) / 2;
+
+  bitblt(
+    &mut output[(offset_y * pad_width + offset_x)..],
+    pad_width,
+    input,
+    stride,
+    width,
+    height,
+  );
+
+  // copy left and right regions
+  for y in offset_y..(offset_y + height) {
+    let dst_line = &mut output[(y * pad_width)..];
+
+    for x in 0..offset_x {
+      dst_line[x] = dst_line[offset_x * 2 - x];
+    }
+
+    for x in (offset_x + width)..pad_width {
+      dst_line[x] = dst_line[2 * (offset_x + width) - 2 - x];
+    }
+  }
+
+  // copy top region
+  for y in 0..offset_y {
+    let dst = output[(y * pad_width)..][..pad_width].as_mut_ptr();
+    let src = output[((offset_y * 2 - y) * pad_width)..][..pad_width].as_ptr();
+    // SAFETY: We check the start and end bounds above.
+    // We have to use `copy_nonoverlapping` because the borrow checker is not happy with
+    // `copy_from_slice`.
+    unsafe {
+      copy_nonoverlapping(src, dst, pad_width);
+    }
+  }
+
+  // copy bottom region
+  for y in (offset_y + height)..pad_height {
+    let dst = output[(y * pad_width)..][..pad_width].as_mut_ptr();
+    let src = output[((2 * (offset_y + height) - 2 - y) * pad_width)..]
+      [..pad_width]
+      .as_ptr();
+    // SAFETY: We check the start and end bounds above.
+    // We have to use `copy_nonoverlapping` because the borrow checker is not happy with
+    // `copy_from_slice`.
+    unsafe {
+      copy_nonoverlapping(src, dst, pad_width);
+    }
+  }
+}
+
+pub fn load_block<T: Pixel>(
+  block: &mut [f32x16], shifted_src: &[T], width: usize, height: usize,
+  bit_depth: usize, window: &[f32x16],
+) {
+  let scale = 1.0f32 / (1 << (bit_depth - 8)) as f32;
+  let offset_x = calc_pad_size(width);
+  let offset_y = calc_pad_size(height);
+  for i in 0..TEMPORAL_SIZE {
+    for j in 0..BLOCK_SIZE {
+      // The compiler will optimize away these branches
+      let vec_input = if size_of::<T>() == 1 {
+        // SAFETY: We know that T is u8
+        let u8s: [u8; 16] = unsafe {
+          *transmute::<_, &[u8; 16]>(cast::<16, _>(
+            &shifted_src[((i * offset_y + j) * offset_x)..][..16],
+          ))
+        };
+        let f32_upper = u8x16::new(u8s).to_u32x8().to_f32x8();
+        let f32_lower = u8x16::new([
+          u8s[8], u8s[9], u8s[10], u8s[11], u8s[12], u8s[13], u8s[14],
+          u8s[15], 0, 0, 0, 0, 0, 0, 0, 0,
+        ])
+        .to_u32x8()
+        .to_f32x8();
+        [f32_upper, f32_lower]
+      } else {
+        // SAFETY: We know that T is u16
+        let u16s: [u16; 16] = unsafe {
+          *transmute::<_, &[u16; 16]>(cast::<16, _>(
+            &shifted_src[((i * offset_y + j) * offset_x)..][..16],
+          ))
+        };
+        let f32_upper = u16x8::new(*cast(&u16s[..8])).to_u32x8().to_f32x8();
+        let f32_lower = u16x8::new(*cast(&u16s[8..])).to_u32x8().to_f32x8();
+        [f32_upper, f32_lower]
+      };
+      let window = &window[i * BLOCK_SIZE + j];
+      let result_upper = scale * window[0] * vec_input[0];
+      let result_lower = scale * window[1] * vec_input[1];
+      block[i * BLOCK_SIZE * 2 + j] = [result_upper, result_lower];
+    }
+  }
+}
+
+pub fn fused(
+  block: &mut [f32x16], sigma: f32, pmin: f32, pmax: f32,
+  window_freq: &[f32x16],
+) {
+  for i in 0..TEMPORAL_SIZE {
+    transpose_16x16::<1>(&mut block[(i * 32)..]);
+    rdft::<16>(&mut block[(i * 32)..]);
+    transpose_32x16(&mut block[(i * 32)..]);
+    dft::<16>(&mut 
block[(i * 32)..]); + } + for i in 0..16 { + dft::<3>(&mut block[(i * 2)..], 16); + } + + let gf = block[0].extract(0) / window_freq[0].extract(0); + remove_mean(block, gf, window_freq); + + frequency_filtering(block, sigma, pmin, pmax); + + add_mean(block, gf, window_freq); + + for i in 0..16 { + idft::<3>(&mut block[(i * 2)..], 16); + } + idft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + transpose_32x16(&mut block[(TEMPORAL_RADIUS * 32)..]); + irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + post_irdft::<16>(&mut block[(TEMPORAL_RADIUS * 32)..]); + transpose_16x16::<1>(&mut block[(TEMPORAL_RADIUS * 32)..]); +} + +pub fn store_block( + shifted_dst: &mut [f32], shifted_block: &[f32x16], width: usize, + height: usize, shifted_window: &[f32x16], +) { + let pad_size = calc_pad_size(width); + for i in 0..BLOCK_SIZE { + let acc = &mut shifted_dst[(i * pad_size)..]; + let mut acc_simd = + [f32x8::new(*cast(&acc[..8])), f32x8::new(*cast(&acc[8..16]))]; + acc_simd[0] = + shifted_block[i][0].mul_add(shifted_window[i][0], acc_simd[0]); + acc_simd[1] = + shifted_block[i][1].mul_add(shifted_window[i][1], acc_simd[1]); + acc[..8].copy_from_slice(acc_simd[0].as_array_ref()); + acc[8..16].copy_from_slice(acc_simd[1].as_array_ref()); + } +} + +pub fn store_frame( + output: &mut [T], shifted_src: &[f32], width: usize, height: usize, + bit_depth: usize, dst_stride: usize, src_stride: usize, +) { + let scale = 1.0f32 / (1 << (bit_depth - 8)) as f32; + let peak = (1u32 << bit_depth) - 1; + for y in 0..height { + for x in 0..width { + // SAFETY: We know the bounds of the planes for src and dest + unsafe { + let clamped = clamp( + (*shifted_src.get_unchecked(y * src_stride + x) / scale + 0.5f32) + as u32, + 0u32, + peak, + ); + *output.get_unchecked_mut(y * dst_stride + x) = T::cast_from(clamped); + } + } + } +} + +fn transpose_16x16(block: &mut [f32x16]) { + for i in 0..2 { + for j in 0..2 { + for k in 0..2 { + let id1 = ((i * 2 + j) * 2 + k) * 2 * STRIDE; + let id2 = (((i * 2 + j) * 2 + k) * 2 + 1) * STRIDE; + let temp1 = blend16::< + 0, + 2, + 16, + 18, + 4, + 6, + 20, + 22, + 8, + 10, + 24, + 26, + 12, + 14, + 28, + 30, + >(block[id1], block[id2]); + let temp2 = blend16::< + 1, + 3, + 17, + 19, + 5, + 7, + 21, + 23, + 9, + 11, + 25, + 27, + 13, + 15, + 29, + 31, + >(block[id1], block[id2]); + block[id1] = temp1; + block[id2] = temp2; + } + + for k in 0..2 { + let id1 = (((i * 2 + j) * 2) * 2 + k) * STRIDE; + let id2 = (((i * 2 + j) * 2 + 1) * 2 + k) * STRIDE; + let temp1 = blend16::< + 0, + 2, + 16, + 18, + 4, + 6, + 20, + 22, + 8, + 10, + 24, + 26, + 12, + 14, + 28, + 30, + >(block[id1], block[id2]); + let temp2 = blend16::< + 1, + 3, + 17, + 19, + 5, + 7, + 21, + 23, + 9, + 11, + 25, + 27, + 13, + 15, + 29, + 31, + >(block[id1], block[id2]); + block[id1] = temp1; + block[id2] = temp2; + } + } + + for j in 0..4 { + let id1 = (i * 8 + j) * STRIDE; + let id2 = (i * 8 + 4 + j) * STRIDE; + // SAFETY: Types are the same size + let temp1 = unsafe { + transmute(blend8::<0, 1, 8, 9, 4, 5, 12, 13>( + transmute(block[id1]), + transmute(block[id2]), + )) + }; + // SAFETY: Types are the same size + let temp2 = unsafe { + transmute(blend8::<2, 3, 10, 11, 6, 7, 14, 15>( + transmute(block[id1]), + transmute(block[id2]), + )) + }; + block[id1] = temp1; + block[id2] = temp2; + } + } + + for i in 0..8 { + // SAFETY: Types are the same size + let temp1 = unsafe { + transmute(blend8::<0, 1, 2, 3, 8, 9, 10, 11>( + transmute(block[i * STRIDE]), + transmute(block[(i + 8) * STRIDE]), + )) + }; + // SAFETY: Types are the 
same size
+    let temp2 = unsafe {
+      transmute(blend8::<4, 5, 6, 7, 12, 13, 14, 15>(
+        transmute(block[i * STRIDE]),
+        transmute(block[(i + 8) * STRIDE]),
+      ))
+    };
+    block[i * STRIDE] = temp1;
+    block[(i + 8) * STRIDE] = temp2;
+  }
+}
+
+fn transpose_32x16(block: &mut [f32x16]) {
+  const STRIDE: usize = 1;
+  transpose_16x16::<2>(block);
+  transpose_16x16::<2>(&mut block[STRIDE..]);
+}
+
+fn remove_mean() {
+  todo!()
+}
+
+fn frequency_filtering() {
+  todo!()
+}
+
+fn add_mean() {
+  todo!()
+}
+
+fn rdft(block: &mut [f32x16]) {
+  todo!()
+}
+
+fn dft(block: &mut [f32x16]) {
+  todo!()
+}
+
+fn idft() {
+  todo!()
+}
+
+fn irdft() {
+  todo!()
+}
+
+fn post_irdft() {
+  todo!()
+}
diff --git a/src/denoise/mod.rs b/src/denoise/mod.rs
new file mode 100644
index 0000000000..429c55f6ec
--- /dev/null
+++ b/src/denoise/mod.rs
@@ -0,0 +1,358 @@
+mod blend;
+mod kernel;
+
+use crate::api::FrameQueue;
+use crate::util::{cast, Aligned};
+use crate::EncoderStatus;
+use arrayvec::ArrayVec;
+use kernel::*;
+use num_complex::Complex32;
+use num_traits::Zero;
+use std::f32::consts::PI;
+use std::iter::once;
+use std::mem::{size_of, transmute};
+use std::ptr::copy_nonoverlapping;
+use std::sync::Arc;
+use v_frame::frame::Frame;
+use v_frame::pixel::Pixel;
+use wide::{f32x8, f64x4};
+
+pub const TEMPORAL_RADIUS: usize = 1;
+const TEMPORAL_SIZE: usize = TEMPORAL_RADIUS * 2 + 1;
+const BLOCK_SIZE: usize = 16;
+const BLOCK_STEP: usize = 12;
+const REAL_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * BLOCK_SIZE;
+const COMPLEX_SIZE: usize = TEMPORAL_SIZE * BLOCK_SIZE * (BLOCK_SIZE / 2 + 1);
+
+// The C implementation uses this f32x16 type, which is implemented the same way
+// for non-avx512 systems. `wide` doesn't have a f32x16 type, so we mimic it.
+#[allow(non_camel_case_types)]
+type f32x16 = [f32x8; 2];
+#[allow(non_camel_case_types)]
+type f64x8 = [f64x4; 2];
+
+/// This denoiser is based on the DFTTest2 plugin from Vapoursynth.
+/// This type of denoising was chosen because it provides
+/// high quality while not being too slow. The DFTTest2 implementation
+/// is much faster than the original due to its custom FFT kernel.
+pub struct DftDenoiser<T>
+where
+  T: Pixel,
+{
+  // External values
+  prev_frame: Option<Arc<Frame<T>>>,
+  pub(crate) cur_frameno: u64,
+  bit_depth: usize,
+  // Local values
+  sigma: f32,
+  window: Aligned<[f32; REAL_SIZE]>,
+  window_freq: Aligned<[Complex32; COMPLEX_SIZE]>,
+  pmin: f32,
+  pmax: f32,
+  padded: Vec<T>,
+  padded2: Vec<f32>,
+}
+
+impl<T> DftDenoiser<T>
+where
+  T: Pixel,
+{
+  // This should only need to run once per video.
+  pub fn new(
+    sigma: f32, width: usize, height: usize, bit_depth: usize,
+  ) -> Self {
+    if size_of::<T>() == 1 {
+      assert!(bit_depth <= 8);
+    } else {
+      assert!(bit_depth > 8);
+    }
+
+    let window = build_window();
+    let wscale = window.iter().copied().map(|w| w * w).sum::<f32>();
+    let sigma = sigma as f32 * wscale;
+    let pmin = 0.0f32;
+    let pmax = 500.0f32 * wscale;
+    let mut window_freq_real = Aligned::new([0f32; REAL_SIZE]);
+    window_freq_real.iter_mut().zip(window.iter()).for_each(|(freq, w)| {
+      *freq = *w as f32 * 255.0;
+    });
+    let window_freq = rdft(&window_freq_real);
+    let w_pad_size = calc_pad_size(width);
+    let h_pad_size = calc_pad_size(height);
+    let pad_size = w_pad_size * h_pad_size;
+    let padded = vec![T::zero(); pad_size * TEMPORAL_SIZE];
+    let padded2 = vec![0f32; pad_size];
+
+    Self {
+      prev_frame: None,
+      cur_frameno: 0,
+      bit_depth,
+      sigma,
+      window,
+      window_freq,
+      pmin,
+      pmax,
+      padded,
+      padded2,
+    }
+  }
+
+  pub fn filter_frame(
+    &mut self, frame_q: &FrameQueue<T>,
+  ) -> Result<Frame<T>, EncoderStatus> {
+    let next_frame = frame_q.get(&(self.cur_frameno + 1));
+    if next_frame.is_none() {
+      // We also need to have the next unfiltered frame,
+      // unless we are at the end of the video.
+      return Err(EncoderStatus::NeedMoreData);
+    }
+
+    let next_frame = next_frame.cloned().flatten();
+    let orig_frame = frame_q.get(&self.cur_frameno).unwrap().as_ref().unwrap();
+    let frames =
+      once(self.prev_frame.clone().unwrap_or_else(|| Arc::clone(orig_frame)))
+        .chain(once(Arc::clone(orig_frame)))
+        .chain(once(next_frame.unwrap_or_else(|| Arc::clone(orig_frame))))
+        .collect::<ArrayVec<_, 3>>();
+
+    let mut dest = (**orig_frame).clone();
+    for p in 0..3 {
+      let width = frames[0].planes[p].cfg.width;
+      let height = frames[0].planes[p].cfg.height;
+      let stride = frames[0].planes[p].cfg.stride;
+      let w_pad_size = calc_pad_size(width);
+      let h_pad_size = calc_pad_size(height);
+      let pad_size_spatial = w_pad_size * h_pad_size;
+      for i in 0..TEMPORAL_SIZE {
+        let src = &frames[i].planes[p];
+        reflection_padding(
+          &mut self.padded[(i * pad_size_spatial)..],
+          src.data_origin(),
+          width,
+          height,
+          stride,
+        )
+      }
+
+      for i in 0..h_pad_size {
+        for j in 0..w_pad_size {
+          let mut block = [f32x16::default(); 7 * BLOCK_SIZE * 2];
+          let offset_x = w_pad_size;
+          load_block(
+            &mut block,
+            &self.padded[((i * offset_x + j) * BLOCK_STEP)..],
+            width,
+            height,
+            self.bit_depth,
+            // SAFETY: We know that the window size is a multiple of 16
+            unsafe { transmute(&self.window[..]) },
+          );
+          fused(
+            &mut block,
+            self.sigma,
+            self.pmin,
+            self.pmax,
+            // SAFETY: We know that the window size is a multiple of 16
+            unsafe { transmute(&self.window_freq[..]) },
+          );
+          store_block(
+            &mut self.padded2[((i * offset_x + j) * BLOCK_STEP)..],
+            &block[(TEMPORAL_RADIUS * BLOCK_SIZE * 2)..],
+            width,
+            height,
+            // SAFETY: We know that the window size is a multiple of 16
+            unsafe {
+              transmute(
+                &self.window[(TEMPORAL_RADIUS * BLOCK_SIZE * 2 * 16)..],
+              )
+            },
+          );
+          todo!()
+        }
+      }
+
+      let offset_y = (h_pad_size - height) / 2;
+      let offset_x = (w_pad_size - width) / 2;
+      let dest_plane = &mut dest.planes[p];
+      store_frame(
+        dest_plane.data_origin_mut(),
+        &self.padded2[(offset_y * w_pad_size + offset_x)..],
+        width,
+        height,
+        self.bit_depth,
+        stride,
+        w_pad_size,
+      );
+    }
+
+    self.prev_frame = Some(Arc::clone(orig_frame));
+    self.cur_frameno += 1;
+
+    Ok(dest)
+  }
+}
+
+#[inline(always)]
+// Hanning windowing
+fn spatial_window_value(n: f32) -> f32 {
+  let temp = PI * n / BLOCK_SIZE as f32;
+  0.5 * (1.0 - (2.0 * temp).cos()) 
+
+#[inline(always)]
+// Simple rectangular windowing
+const fn temporal_window_value() -> f32 {
+  1.0
+}
+
+fn build_window() -> Aligned<[f32; REAL_SIZE]> {
+  let temporal_window = [temporal_window_value(); TEMPORAL_SIZE];
+
+  let mut spatial_window = [0f32; BLOCK_SIZE];
+  spatial_window.iter_mut().enumerate().for_each(|(i, val)| {
+    *val = spatial_window_value(i as f32 + 0.5);
+  });
+  let spatial_window = normalize(&spatial_window);
+
+  let mut window = Aligned::new([0f32; REAL_SIZE]);
+  let mut i = 0;
+  for t_val in temporal_window {
+    for s_val1 in spatial_window {
+      for s_vals2 in spatial_window.chunks_exact(8) {
+        let s_val2 = f32x8::new(*cast::<8, _>(s_vals2));
+        let mut value = t_val * s_val1 * s_val2;
+        // normalize for unnormalized FFT implementation
+        value /=
+          f32x8::from((TEMPORAL_SIZE as f32).sqrt() * BLOCK_SIZE as f32);
+        // SAFETY: We know the slices are valid sizes
+        unsafe {
+          copy_nonoverlapping(
+            value.as_array_ref().as_ptr(),
+            window.as_mut_ptr().add(i),
+            8usize,
+          )
+        };
+        i += 8;
+      }
+    }
+  }
+  window
+}
+
+fn normalize(window: &[f32; BLOCK_SIZE]) -> [f32; BLOCK_SIZE] {
+  let mut new_window = [0f32; BLOCK_SIZE];
+  // SAFETY: We know all of the sizes, so bounds checks are not needed.
+  unsafe {
+    for q in 0..BLOCK_SIZE {
+      let nw = new_window.get_unchecked_mut(q);
+      for h in (0..=q).rev().step_by(BLOCK_STEP) {
+        *nw += window.get_unchecked(h).powi(2);
+      }
+      for h in ((q + BLOCK_STEP)..BLOCK_SIZE).step_by(BLOCK_STEP) {
+        *nw += window.get_unchecked(h).powi(2);
+      }
+    }
+  }
+  for (w, nw) in window.iter().zip(new_window.iter_mut()) {
+    *nw = *w / nw.sqrt();
+  }
+  new_window
+}
+
+// Identical to VapourSynth's `vs_bitblt`,
+// which simply copies the pixels of a plane.
+pub fn bitblt<T: Copy>(
+  mut dest: &mut [T], dest_stride: usize, mut src: &[T], src_stride: usize,
+  width: usize, height: usize,
+) {
+  if src_stride == dest_stride && src_stride == width {
+    dest[..(width * height)].copy_from_slice(&src[..(width * height)]);
+  } else {
+    for _ in 0..height {
+      dest[..width].copy_from_slice(&src[..width]);
+      src = &src[src_stride..];
+      dest = &mut dest[dest_stride..];
+    }
+  }
+}
+
+fn rdft(
+  input: &Aligned<[f32; REAL_SIZE]>,
+) -> Aligned<[Complex32; COMPLEX_SIZE]> {
+  const SHAPE: [usize; 3] = [TEMPORAL_SIZE, BLOCK_SIZE, BLOCK_SIZE];
+
+  let mut output = Aligned::new([Complex32::zero(); COMPLEX_SIZE]);
+
+  for i in 0..(SHAPE[0] * SHAPE[1]) {
+    dft(
+      &mut output[(i * (SHAPE[1] / 2 + 1))..],
+      DftInput::Real(&input[(i * SHAPE[1])..]),
+      SHAPE[2],
+      1,
+    );
+  }
+
+  let mut output2 = Aligned::new([Complex32::zero(); COMPLEX_SIZE]);
+
+  let stride = SHAPE[2] / 2 + 1;
+  for i in 0..SHAPE[0] {
+    for j in 0..stride {
+      dft(
+        &mut output2[(i * SHAPE[1] * stride + j)..],
+        DftInput::Complex(&output[(i * SHAPE[1] * stride + j)..]),
+        SHAPE[1],
+        stride,
+      );
+    }
+  }
+
+  let stride = SHAPE[1] * stride;
+  for i in 0..stride {
+    dft(&mut output[i..], DftInput::Complex(&output2[i..]), SHAPE[0], stride);
+  }
+
+  output
+}
+
+enum DftInput<'a> {
+  Real(&'a [f32]),
+  Complex(&'a [Complex32]),
+}
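+
+// Note: `dft` below is a naive O(n^2) transform along a single axis;
+// `rdft` composes it across the temporal and both spatial axes. It is used
+// here only to precompute `window_freq` in `DftDenoiser::new`.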
+
+#[inline(always)]
+fn dft(output: &mut [Complex32], input: DftInput, n: usize, stride: usize) {
+  match input {
+    DftInput::Real(input) => {
+      let out_num = n / 2 + 1;
+      for i in 0..out_num {
+        let mut sum = Complex32::zero();
+        for j in 0..n {
+          let imag = -2f32 * i as f32 * j as f32 * PI / n as f32;
+          let weight = Complex32::new(imag.cos(), imag.sin());
+          sum += input[j * stride] * weight;
+        }
+        output[i * stride] = sum;
+      }
+    }
+    DftInput::Complex(input) => {
+      let out_num = n;
+      for i in 0..out_num {
+        let mut sum = Complex32::zero();
+        for j in 0..n {
+          let imag = -2f32 * i as f32 * j as f32 * PI / n as f32;
+          let weight = Complex32::new(imag.cos(), imag.sin());
+          sum += input[j * stride] * weight;
+        }
+        output[i * stride] = sum;
+      }
+    }
+  }
+}
+
+#[inline(always)]
+// Pad the plane dimension up to a whole number of blocks, with extra margin
+// for the sliding-block overlap.
+fn calc_pad_size(size: usize) -> usize {
+  size
+    + if size % BLOCK_SIZE > 0 { BLOCK_SIZE - size % BLOCK_SIZE } else { 0 }
+    + BLOCK_SIZE
+    - BLOCK_STEP.max(BLOCK_STEP) * 2
+}
diff --git a/src/fuzzing.rs b/src/fuzzing.rs
index aab9abe059..d440f3f88f 100644
--- a/src/fuzzing.rs
+++ b/src/fuzzing.rs
@@ -257,6 +257,7 @@ impl Arbitrary for ArbitraryEncoder {
       switch_frame_interval: u.int_in_range(0..=3)?,
       tune: *u.choose(&[Tune::Psnr, Tune::Psychovisual])?,
       film_grain_params: None,
+      denoise_strength: u.int_in_range(0..=50)?,
     };
 
     let frame_count =
diff --git a/src/lib.rs b/src/lib.rs
index 3425588db4..0c31e91123 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -257,6 +257,8 @@ mod cdef;
 #[doc(hidden)]
 pub mod context;
 mod deblock;
+#[doc(hidden)]
+pub mod denoise;
 mod encoder;
 mod entropymode;
 mod lrf;
@@ -446,6 +448,9 @@ pub mod bench {
   pub mod context {
     pub use crate::context::*;
   }
+  pub mod denoise {
+    pub use crate::denoise::*;
+  }
   pub mod dist {
     pub use crate::dist::*;
   }
diff --git a/src/test_encode_decode/mod.rs b/src/test_encode_decode/mod.rs
index 9e6082ed55..8bdecca838 100644
--- a/src/test_encode_decode/mod.rs
+++ b/src/test_encode_decode/mod.rs
@@ -10,9 +10,8 @@
 // Fuzzing only uses a subset of these.
 #![cfg_attr(fuzzing, allow(unused))]
 
-use crate::color::ChromaSampling;
-
 use crate::api::config::GrainTableSegment;
+use crate::color::ChromaSampling;
 use crate::util::Pixel;
 use crate::*;
 
diff --git a/src/util/align.rs b/src/util/align.rs
index c86424e8b2..02928698cd 100644
--- a/src/util/align.rs
+++ b/src/util/align.rs
@@ -42,6 +42,20 @@ impl<T> Aligned<T> {
   }
 }
 
+impl<T> std::ops::Deref for Aligned<T> {
+  type Target = T;
+
+  fn deref(&self) -> &T {
+    &self.data
+  }
+}
+
+impl<T> std::ops::DerefMut for Aligned<T> {
+  fn deref_mut(&mut self) -> &mut T {
+    &mut self.data
+  }
+}
+
 /// An analog to a Box<[T]> where the underlying slice is aligned.
 /// Alignment is according to the architecture-specific SIMD constraints.
 pub struct AlignedBoxedSlice<T> {
diff --git a/src/util/array.rs b/src/util/array.rs
new file mode 100644
index 0000000000..afeba3f244
--- /dev/null
+++ b/src/util/array.rs
@@ -0,0 +1,13 @@
+pub fn cast<const N: usize, T>(x: &[T]) -> &[T; N] {
+  // SAFETY: we perform a bounds check with [..N],
+  // so casting to *const [T; N] is valid because the bounds
+  // check guarantees that x has N elements
+  unsafe { &*(&x[..N] as *const [T] as *const [T; N]) }
+}
+
+pub fn cast_mut<const N: usize, T>(x: &mut [T]) -> &mut [T; N] {
+  // SAFETY: we perform a bounds check with [..N],
+  // so casting to *mut [T; N] is valid because the bounds
+  // check guarantees that x has N elements
+  unsafe { &mut *(&mut x[..N] as *mut [T] as *mut [T; N]) }
+}
diff --git a/src/util/mod.rs b/src/util/mod.rs
index e5af5c2461..c7c30aef45 100644
--- a/src/util/mod.rs
+++ b/src/util/mod.rs
@@ -8,6 +8,7 @@
 // PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 
 mod align;
+mod array;
 #[macro_use]
 mod cdf;
 mod kmeans;
@@ -18,6 +19,7 @@
 pub use v_frame::math::*;
 pub use v_frame::pixel::*;
 
 pub use align::*;
+pub(crate) use array::*;
 pub use uninit::*;
 pub use kmeans::*;
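
A quick illustration of the new `cast` helper from `src/util/array.rs`, which `build_window` uses to reinterpret each 8-float chunk of the spatial window as a fixed-size array for `f32x8::new`. This is a minimal standalone sketch: the function body is copied locally because the in-tree helper is only crate-visible (`pub(crate) use array::*;`), and the sample values are illustrative only.

```rust
// Local copy of the new helper so this sketch stands alone.
pub fn cast<const N: usize, T>(x: &[T]) -> &[T; N] {
  // SAFETY: the bounds check performed by `[..N]` guarantees `x` has at
  // least N elements, so the reborrow as `&[T; N]` is valid.
  unsafe { &*(&x[..N] as *const [T] as *const [T; N]) }
}

fn main() {
  let window = vec![0.25f32; 16];
  // Reborrow the first 8 elements as `&[f32; 8]`, as `build_window` does
  // before building an `f32x8` lane; panics if the slice is shorter than 8.
  let chunk: &[f32; 8] = cast::<8, _>(&window[..]);
  assert_eq!(chunk.len(), 8);
}
```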