Skip to content

Commit 308cb64

Browse files
committed
fix: avx512 codegen by multiversion
Signed-off-by: usamoi <[email protected]>
1 parent 8a925e3 commit 308cb64

File tree

8 files changed

+319
-53
lines changed

8 files changed

+319
-53
lines changed

crates/service/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#![feature(core_intrinsics)]
2+
#![feature(avx512_target_feature)]
23

34
pub mod algorithms;
45
pub mod index;

crates/service/src/prelude/global/f16.rs

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,12 @@ use crate::prelude::*;
22

33
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
44
#[inline(always)]
5-
#[multiversion::multiversion(targets = "simd")]
5+
#[multiversion::multiversion(targets(
6+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
7+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
8+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
9+
"aarch64+neon"
10+
))]
611
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
712
assert!(lhs.len() == rhs.len());
813
let n = lhs.len();
@@ -37,7 +42,12 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
3742

3843
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
3944
#[inline(always)]
40-
#[multiversion::multiversion(targets = "simd")]
45+
#[multiversion::multiversion(targets(
46+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
47+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
48+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
49+
"aarch64+neon"
50+
))]
4151
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
4252
assert!(lhs.len() == rhs.len());
4353
let n = lhs.len();
@@ -68,7 +78,12 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
6878

6979
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
7080
#[inline(always)]
71-
#[multiversion::multiversion(targets = "simd")]
81+
#[multiversion::multiversion(targets(
82+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
83+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
84+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
85+
"aarch64+neon"
86+
))]
7287
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
7388
assert!(lhs.len() == rhs.len());
7489
let n = lhs.len();

crates/service/src/prelude/global/f16_cos.rs

Lines changed: 54 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,12 @@ impl G for F16Cos {
2424
super::f16::dot(lhs, rhs).acos()
2525
}
2626

27-
#[multiversion::multiversion(targets = "simd")]
27+
#[multiversion::multiversion(targets(
28+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
29+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
30+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
31+
"aarch64+neon"
32+
))]
2833
fn scalar_quantization_distance(
2934
dims: u16,
3035
max: &[F16],
@@ -45,7 +50,12 @@ impl G for F16Cos {
4550
xy / (x2 * y2).sqrt() * (-1.0)
4651
}
4752

48-
#[multiversion::multiversion(targets = "simd")]
53+
#[multiversion::multiversion(targets(
54+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
55+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
56+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
57+
"aarch64+neon"
58+
))]
4959
fn scalar_quantization_distance2(
5060
dims: u16,
5161
max: &[F16],
@@ -66,7 +76,12 @@ impl G for F16Cos {
6676
xy / (x2 * y2).sqrt() * (-1.0)
6777
}
6878

69-
#[multiversion::multiversion(targets = "simd")]
79+
#[multiversion::multiversion(targets(
80+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
81+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
82+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
83+
"aarch64+neon"
84+
))]
7085
fn product_quantization_distance(
7186
dims: u16,
7287
ratio: u16,
@@ -91,7 +106,12 @@ impl G for F16Cos {
91106
xy / (x2 * y2).sqrt() * (-1.0)
92107
}
93108

94-
#[multiversion::multiversion(targets = "simd")]
109+
#[multiversion::multiversion(targets(
110+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
111+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
112+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
113+
"aarch64+neon"
114+
))]
95115
fn product_quantization_distance2(
96116
dims: u16,
97117
ratio: u16,
@@ -117,7 +137,12 @@ impl G for F16Cos {
117137
xy / (x2 * y2).sqrt() * (-1.0)
118138
}
119139

120-
#[multiversion::multiversion(targets = "simd")]
140+
#[multiversion::multiversion(targets(
141+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
142+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
143+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
144+
"aarch64+neon"
145+
))]
121146
fn product_quantization_distance_with_delta(
122147
dims: u16,
123148
ratio: u16,
@@ -146,7 +171,12 @@ impl G for F16Cos {
146171
}
147172

148173
#[inline(always)]
149-
#[multiversion::multiversion(targets = "simd")]
174+
#[multiversion::multiversion(targets(
175+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
176+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
177+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
178+
"aarch64+neon"
179+
))]
150180
fn length(vector: &[F16]) -> F16 {
151181
let n = vector.len();
152182
let mut dot = F16::zero();
@@ -157,7 +187,12 @@ fn length(vector: &[F16]) -> F16 {
157187
}
158188

159189
#[inline(always)]
160-
#[multiversion::multiversion(targets = "simd")]
190+
#[multiversion::multiversion(targets(
191+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
192+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
193+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
194+
"aarch64+neon"
195+
))]
161196
fn l2_normalize(vector: &mut [F16]) {
162197
let n = vector.len();
163198
let l = length(vector);
@@ -167,7 +202,12 @@ fn l2_normalize(vector: &mut [F16]) {
167202
}
168203

169204
#[inline(always)]
170-
#[multiversion::multiversion(targets = "simd")]
205+
#[multiversion::multiversion(targets(
206+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
207+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
208+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
209+
"aarch64+neon"
210+
))]
171211
fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
172212
assert!(lhs.len() == rhs.len());
173213
let n = lhs.len();
@@ -183,7 +223,12 @@ fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
183223
}
184224

185225
#[inline(always)]
186-
#[multiversion::multiversion(targets = "simd")]
226+
#[multiversion::multiversion(targets(
227+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
228+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
229+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
230+
"aarch64+neon"
231+
))]
187232
fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) {
188233
assert!(lhs.len() == rhs.len());
189234
let n = lhs.len();

crates/service/src/prelude/global/f16_dot.rs

Lines changed: 48 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,12 @@ impl G for F16Dot {
2424
super::f16::dot(lhs, rhs).acos()
2525
}
2626

27-
#[multiversion::multiversion(targets = "simd")]
27+
#[multiversion::multiversion(targets(
28+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
29+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
30+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
31+
"aarch64+neon"
32+
))]
2833
fn scalar_quantization_distance(
2934
dims: u16,
3035
max: &[F16],
@@ -41,7 +46,12 @@ impl G for F16Dot {
4146
xy * (-1.0)
4247
}
4348

44-
#[multiversion::multiversion(targets = "simd")]
49+
#[multiversion::multiversion(targets(
50+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
51+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
52+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
53+
"aarch64+neon"
54+
))]
4555
fn scalar_quantization_distance2(
4656
dims: u16,
4757
max: &[F16],
@@ -58,7 +68,12 @@ impl G for F16Dot {
5868
xy * (-1.0)
5969
}
6070

61-
#[multiversion::multiversion(targets = "simd")]
71+
#[multiversion::multiversion(targets(
72+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
73+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
74+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
75+
"aarch64+neon"
76+
))]
6277
fn product_quantization_distance(
6378
dims: u16,
6479
ratio: u16,
@@ -79,7 +94,12 @@ impl G for F16Dot {
7994
xy * (-1.0)
8095
}
8196

82-
#[multiversion::multiversion(targets = "simd")]
97+
#[multiversion::multiversion(targets(
98+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
99+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
100+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
101+
"aarch64+neon"
102+
))]
83103
fn product_quantization_distance2(
84104
dims: u16,
85105
ratio: u16,
@@ -101,7 +121,12 @@ impl G for F16Dot {
101121
xy * (-1.0)
102122
}
103123

104-
#[multiversion::multiversion(targets = "simd")]
124+
#[multiversion::multiversion(targets(
125+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
126+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
127+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
128+
"aarch64+neon"
129+
))]
105130
fn product_quantization_distance_with_delta(
106131
dims: u16,
107132
ratio: u16,
@@ -126,7 +151,12 @@ impl G for F16Dot {
126151
}
127152

128153
#[inline(always)]
129-
#[multiversion::multiversion(targets = "simd")]
154+
#[multiversion::multiversion(targets(
155+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
156+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
157+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
158+
"aarch64+neon"
159+
))]
130160
fn length(vector: &[F16]) -> F16 {
131161
let n = vector.len();
132162
let mut dot = F16::zero();
@@ -137,7 +167,12 @@ fn length(vector: &[F16]) -> F16 {
137167
}
138168

139169
#[inline(always)]
140-
#[multiversion::multiversion(targets = "simd")]
170+
#[multiversion::multiversion(targets(
171+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
172+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
173+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
174+
"aarch64+neon"
175+
))]
141176
fn l2_normalize(vector: &mut [F16]) {
142177
let n = vector.len();
143178
let l = length(vector);
@@ -147,7 +182,12 @@ fn l2_normalize(vector: &mut [F16]) {
147182
}
148183

149184
#[inline(always)]
150-
#[multiversion::multiversion(targets = "simd")]
185+
#[multiversion::multiversion(targets(
186+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
187+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
188+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
189+
"aarch64+neon"
190+
))]
151191
fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
152192
assert!(lhs.len() == rhs.len());
153193
let n: usize = lhs.len();

crates/service/src/prelude/global/f16_l2.rs

Lines changed: 36 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,12 @@ impl G for F16L2 {
2323
super::f16::sl2(lhs, rhs).sqrt()
2424
}
2525

26-
#[multiversion::multiversion(targets = "simd")]
26+
#[multiversion::multiversion(targets(
27+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
28+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
29+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
30+
"aarch64+neon"
31+
))]
2732
fn scalar_quantization_distance(
2833
dims: u16,
2934
max: &[F16],
@@ -40,7 +45,12 @@ impl G for F16L2 {
4045
result
4146
}
4247

43-
#[multiversion::multiversion(targets = "simd")]
48+
#[multiversion::multiversion(targets(
49+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
50+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
51+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
52+
"aarch64+neon"
53+
))]
4454
fn scalar_quantization_distance2(
4555
dims: u16,
4656
max: &[F16],
@@ -57,7 +67,12 @@ impl G for F16L2 {
5767
result
5868
}
5969

60-
#[multiversion::multiversion(targets = "simd")]
70+
#[multiversion::multiversion(targets(
71+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
72+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
73+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
74+
"aarch64+neon"
75+
))]
6176
fn product_quantization_distance(
6277
dims: u16,
6378
ratio: u16,
@@ -77,7 +92,12 @@ impl G for F16L2 {
7792
result
7893
}
7994

80-
#[multiversion::multiversion(targets = "simd")]
95+
#[multiversion::multiversion(targets(
96+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
97+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
98+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
99+
"aarch64+neon"
100+
))]
81101
fn product_quantization_distance2(
82102
dims: u16,
83103
ratio: u16,
@@ -98,7 +118,12 @@ impl G for F16L2 {
98118
result
99119
}
100120

101-
#[multiversion::multiversion(targets = "simd")]
121+
#[multiversion::multiversion(targets(
122+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
123+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
124+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
125+
"aarch64+neon"
126+
))]
102127
fn product_quantization_distance_with_delta(
103128
dims: u16,
104129
ratio: u16,
@@ -122,7 +147,12 @@ impl G for F16L2 {
122147
}
123148

124149
#[inline(always)]
125-
#[multiversion::multiversion(targets = "simd")]
150+
#[multiversion::multiversion(targets(
151+
"x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
152+
"x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma",
153+
"x86_64+ssse3+sse4.1+sse3+sse2+sse+fma",
154+
"aarch64+neon"
155+
))]
126156
fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
127157
assert!(lhs.len() == rhs.len());
128158
let n = lhs.len();

0 commit comments

Comments
 (0)