Skip to content

Commit

Permalink
Use const generics for right shift in forward transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
barrbrain committed Oct 25, 2023
1 parent 8eb0c4b commit 4d1276c
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 174 deletions.
70 changes: 18 additions & 52 deletions src/asm/aarch64/transform/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,6 @@ use debug_unreachable::debug_unreachable;

use core::arch::aarch64::*;

#[inline]
unsafe fn vrshrq_n_s32_switch(a: int32x4_t, n: i32) -> int32x4_t {
match n {
0 => a,
1 => vrshrq_n_s32(a, 1),
2 => vrshrq_n_s32(a, 2),
3 => vrshrq_n_s32(a, 3),
4 => vrshrq_n_s32(a, 4),
5 => vrshrq_n_s32(a, 5),
6 => vrshrq_n_s32(a, 6),
7 => vrshrq_n_s32(a, 7),
8 => vrshrq_n_s32(a, 8),
9 => vrshrq_n_s32(a, 9),
10 => vrshrq_n_s32(a, 10),
11 => vrshrq_n_s32(a, 11),
12 => vrshrq_n_s32(a, 12),
13 => vrshrq_n_s32(a, 13),
14 => vrshrq_n_s32(a, 14),
15 => vrshrq_n_s32(a, 15),
16 => vrshrq_n_s32(a, 16),
17 => vrshrq_n_s32(a, 17),
18 => vrshrq_n_s32(a, 18),
19 => vrshrq_n_s32(a, 19),
20 => vrshrq_n_s32(a, 20),
21 => vrshrq_n_s32(a, 21),
22 => vrshrq_n_s32(a, 22),
23 => vrshrq_n_s32(a, 23),
24 => vrshrq_n_s32(a, 24),
25 => vrshrq_n_s32(a, 25),
26 => vrshrq_n_s32(a, 26),
27 => vrshrq_n_s32(a, 27),
28 => vrshrq_n_s32(a, 28),
29 => vrshrq_n_s32(a, 29),
30 => vrshrq_n_s32(a, 30),
31 => vrshrq_n_s32(a, 31),
32 => vrshrq_n_s32(a, 32),
_ => unreachable!(),
}
}

#[derive(Clone, Copy)]
#[repr(transparent)]
struct I32X8(int32x4x2_t);
Expand Down Expand Up @@ -92,10 +52,10 @@ impl TxOperations for I32X8 {
}

#[inline]
unsafe fn tx_mul(self, mul: (i32, i32)) -> Self {
unsafe fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
I32X8::new(
vrshrq_n_s32_switch(vmulq_n_s32(self.vec().0, mul.0), mul.1),
vrshrq_n_s32_switch(vmulq_n_s32(self.vec().1, mul.0), mul.1),
vrshrq_n_s32(vmulq_n_s32(self.vec().0, mul), SHIFT),
vrshrq_n_s32(vmulq_n_s32(self.vec().1, mul), SHIFT),
)
}

Expand Down Expand Up @@ -268,11 +228,8 @@ unsafe fn shift_left_neon(a: I32X8, shift: u8) -> I32X8 {
}

#[inline]
unsafe fn shift_right_neon(a: I32X8, shift: u8) -> I32X8 {
I32X8::new(
vrshrq_n_s32_switch(a.vec().0, shift.into()),
vrshrq_n_s32_switch(a.vec().1, shift.into()),
)
unsafe fn shift_right_neon<const SHIFT: i32>(a: I32X8) -> I32X8 {
I32X8::new(vrshrq_n_s32(a.vec().0, SHIFT), vrshrq_n_s32(a.vec().1, SHIFT))
}

#[inline]
Expand All @@ -285,11 +242,20 @@ unsafe fn round_shift_array_neon(arr: &mut [I32X8], bit: i8) {
return;
}
if bit > 0 {
let shift = bit as u8;
for s in arr.chunks_exact_mut(4) {
for chunk in s {
*chunk = shift_right_neon(*chunk, shift);
if bit == 1 {
for s in arr.chunks_exact_mut(4) {
for chunk in s {
*chunk = shift_right_neon::<1>(*chunk)
}
}
} else if bit == 2 {
for s in arr.chunks_exact_mut(4) {
for chunk in s {
*chunk = shift_right_neon::<2>(*chunk)
}
}
} else {
debug_unreachable!();
}
} else {
let shift = (-bit) as u8;
Expand Down
8 changes: 4 additions & 4 deletions src/asm/x86/transform/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ impl TxOperations for I32X8 {

#[target_feature(enable = "avx2")]
#[inline]
unsafe fn tx_mul(self, mul: (i32, i32)) -> Self {
unsafe fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
I32X8::new(_mm256_srav_epi32(
_mm256_add_epi32(
_mm256_mullo_epi32(self.vec(), _mm256_set1_epi32(mul.0)),
_mm256_set1_epi32(1 << mul.1 >> 1),
_mm256_mullo_epi32(self.vec(), _mm256_set1_epi32(mul)),
_mm256_set1_epi32(1 << SHIFT >> 1),
),
_mm256_set1_epi32(mul.1),
_mm256_set1_epi32(SHIFT),
))
}

Expand Down
4 changes: 2 additions & 2 deletions src/transform/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ pub mod rust {
0
}

fn tx_mul(self, mul: (i32, i32)) -> Self {
((self * mul.0) + (1 << mul.1 >> 1)) >> mul.1
fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
((self * mul) + (1 << SHIFT >> 1)) >> SHIFT
}

fn rshift1(self) -> Self {
Expand Down
Loading

0 comments on commit 4d1276c

Please sign in to comment.