Skip to content

Commit 4d1276c

Browse files
committed
Use const generics for right shift in forward transforms
1 parent 8eb0c4b commit 4d1276c

File tree

4 files changed, with 140 additions and 174 deletions.

src/asm/aarch64/transform/forward.rs

Lines changed: 18 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -19,46 +19,6 @@ use debug_unreachable::debug_unreachable;
1919

2020
use core::arch::aarch64::*;
2121

22-
#[inline]
23-
unsafe fn vrshrq_n_s32_switch(a: int32x4_t, n: i32) -> int32x4_t {
24-
match n {
25-
0 => a,
26-
1 => vrshrq_n_s32(a, 1),
27-
2 => vrshrq_n_s32(a, 2),
28-
3 => vrshrq_n_s32(a, 3),
29-
4 => vrshrq_n_s32(a, 4),
30-
5 => vrshrq_n_s32(a, 5),
31-
6 => vrshrq_n_s32(a, 6),
32-
7 => vrshrq_n_s32(a, 7),
33-
8 => vrshrq_n_s32(a, 8),
34-
9 => vrshrq_n_s32(a, 9),
35-
10 => vrshrq_n_s32(a, 10),
36-
11 => vrshrq_n_s32(a, 11),
37-
12 => vrshrq_n_s32(a, 12),
38-
13 => vrshrq_n_s32(a, 13),
39-
14 => vrshrq_n_s32(a, 14),
40-
15 => vrshrq_n_s32(a, 15),
41-
16 => vrshrq_n_s32(a, 16),
42-
17 => vrshrq_n_s32(a, 17),
43-
18 => vrshrq_n_s32(a, 18),
44-
19 => vrshrq_n_s32(a, 19),
45-
20 => vrshrq_n_s32(a, 20),
46-
21 => vrshrq_n_s32(a, 21),
47-
22 => vrshrq_n_s32(a, 22),
48-
23 => vrshrq_n_s32(a, 23),
49-
24 => vrshrq_n_s32(a, 24),
50-
25 => vrshrq_n_s32(a, 25),
51-
26 => vrshrq_n_s32(a, 26),
52-
27 => vrshrq_n_s32(a, 27),
53-
28 => vrshrq_n_s32(a, 28),
54-
29 => vrshrq_n_s32(a, 29),
55-
30 => vrshrq_n_s32(a, 30),
56-
31 => vrshrq_n_s32(a, 31),
57-
32 => vrshrq_n_s32(a, 32),
58-
_ => unreachable!(),
59-
}
60-
}
61-
6222
#[derive(Clone, Copy)]
6323
#[repr(transparent)]
6424
struct I32X8(int32x4x2_t);
@@ -92,10 +52,10 @@ impl TxOperations for I32X8 {
9252
}
9353

9454
#[inline]
95-
unsafe fn tx_mul(self, mul: (i32, i32)) -> Self {
55+
unsafe fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
9656
I32X8::new(
97-
vrshrq_n_s32_switch(vmulq_n_s32(self.vec().0, mul.0), mul.1),
98-
vrshrq_n_s32_switch(vmulq_n_s32(self.vec().1, mul.0), mul.1),
57+
vrshrq_n_s32(vmulq_n_s32(self.vec().0, mul), SHIFT),
58+
vrshrq_n_s32(vmulq_n_s32(self.vec().1, mul), SHIFT),
9959
)
10060
}
10161

@@ -268,11 +228,8 @@ unsafe fn shift_left_neon(a: I32X8, shift: u8) -> I32X8 {
268228
}
269229

270230
#[inline]
271-
unsafe fn shift_right_neon(a: I32X8, shift: u8) -> I32X8 {
272-
I32X8::new(
273-
vrshrq_n_s32_switch(a.vec().0, shift.into()),
274-
vrshrq_n_s32_switch(a.vec().1, shift.into()),
275-
)
231+
unsafe fn shift_right_neon<const SHIFT: i32>(a: I32X8) -> I32X8 {
232+
I32X8::new(vrshrq_n_s32(a.vec().0, SHIFT), vrshrq_n_s32(a.vec().1, SHIFT))
276233
}
277234

278235
#[inline]
@@ -285,11 +242,20 @@ unsafe fn round_shift_array_neon(arr: &mut [I32X8], bit: i8) {
285242
return;
286243
}
287244
if bit > 0 {
288-
let shift = bit as u8;
289-
for s in arr.chunks_exact_mut(4) {
290-
for chunk in s {
291-
*chunk = shift_right_neon(*chunk, shift);
245+
if bit == 1 {
246+
for s in arr.chunks_exact_mut(4) {
247+
for chunk in s {
248+
*chunk = shift_right_neon::<1>(*chunk)
249+
}
292250
}
251+
} else if bit == 2 {
252+
for s in arr.chunks_exact_mut(4) {
253+
for chunk in s {
254+
*chunk = shift_right_neon::<2>(*chunk)
255+
}
256+
}
257+
} else {
258+
debug_unreachable!();
293259
}
294260
} else {
295261
let shift = (-bit) as u8;

src/asm/x86/transform/forward.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,13 +55,13 @@ impl TxOperations for I32X8 {
5555

5656
#[target_feature(enable = "avx2")]
5757
#[inline]
58-
unsafe fn tx_mul(self, mul: (i32, i32)) -> Self {
58+
unsafe fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
5959
I32X8::new(_mm256_srav_epi32(
6060
_mm256_add_epi32(
61-
_mm256_mullo_epi32(self.vec(), _mm256_set1_epi32(mul.0)),
62-
_mm256_set1_epi32(1 << mul.1 >> 1),
61+
_mm256_mullo_epi32(self.vec(), _mm256_set1_epi32(mul)),
62+
_mm256_set1_epi32(1 << SHIFT >> 1),
6363
),
64-
_mm256_set1_epi32(mul.1),
64+
_mm256_set1_epi32(SHIFT),
6565
))
6666
}
6767

src/transform/forward.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,8 +39,8 @@ pub mod rust {
3939
0
4040
}
4141

42-
fn tx_mul(self, mul: (i32, i32)) -> Self {
43-
((self * mul.0) + (1 << mul.1 >> 1)) >> mul.1
42+
fn tx_mul<const SHIFT: i32>(self, mul: i32) -> Self {
43+
((self * mul) + (1 << SHIFT >> 1)) >> SHIFT
4444
}
4545

4646
fn rshift1(self) -> Self {

0 commit comments

Comments
 (0)