<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="author" content="Nikita Sivukhin"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Find slice element position in Rust, fast!</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<div class="page">
<div class="meta">
<header>
<h2><a href="/">naming is hard</a></h2>
</header>
</div>
<div class="article">
<section id="Find-slice-element-position-in-Rust-fast">
<a class="heading-anchor" href="#Find-slice-element-position-in-Rust-fast"><h1 date="2024/01/13">Find slice element position in Rust, fast!</h1>
</a><p>I started learning <code>Rust</code> only recently, and while exploring the <a href="https://doc.rust-lang.org/std/primitive.slice.html">slice methods</a> I was a bit surprised not to find a method that returns the position of an element in a slice:</p>
<pre class="noline language-rust"><code><span class="keyword">fn</span> <span class="function">find</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> { ... }</code>
</pre>
<p>I had some experience with <code>Zig</code>, which has a very useful <a href="https://ziglang.org/documentation/master/std/#A;std:mem"><code>std.mem</code></a> module with many generic functions, including <code>indexOf</code>, which internally implements the <a href="https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore%E2%80%93Horspool_algorithm">Boyer-Moore-Horspool</a> pattern matching algorithm for a generic element type <code>T</code>:</p>
<pre class="noline language-zig"><code><span class="keyword">fn</span> <span class="function">indexOf</span>(<span class="keyword">comptime</span> <span class="identifier">T</span>: <span class="identifier">type</span>, <span class="identifier">haystack</span>: []<span class="keyword">const</span> <span class="identifier">T</span>, <span class="identifier">needle</span>: []<span class="keyword">const</span> <span class="identifier">T</span>) ?<span class="identifier">usize</span> { ... }</code>
</pre>
<p>After discussing this with <code>Rust</code> experts, I quickly learned that I can just use the methods of the <code>Iterator</code> trait:</p>
<pre class="language-rust"><code><span class="keyword">fn</span> <span class="function">find</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> {</code>
<code> <span class="identifier">haystack</span>.<span class="function">iter</span>().<span class="function">position</span>(|&<span class="identifier">x</span>| <span class="identifier">x</span> == <span class="identifier">needle</span>)</code>
<code>}</code>
</pre>
<p>Nice! But what about the performance of this method? At first, I was afraid that passing a closure would lead to poor performance (I come from <code>Go</code>, whose non-<code>LLVM</code>-based compiler has fairly limited inlining power). But, unsurprisingly for most developers, <code>LLVM</code> (and <code>Rust</code>) can optimize this method very nicely, and <code>rustc</code> produces a <a href="https://godbolt.org/z/YrvjKfx1v">very clean</a> binary with the <code>-C opt-level=3 -C target-cpu=native</code> release profile flags:</p>
<pre class="language-asm"><code><span class="comment"># input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle</span></code>
<code><span class="comment"># output: rax=None/Some, rdx=Some(v)</span></code>
<code>example::find:</code>
<code> <span class="keyword">test</span> rsi, rsi <span class="comment"># if haystack.len() == 0</span></code>
<code> <span class="keyword">je</span> .LBB0_1 <span class="comment"># return None<usize></span></code>
<code> <span class="keyword">mov</span> ecx, edx <span class="comment"># b = needle</span></code>
<code> <span class="keyword">xor</span> eax, eax <span class="comment"># result = None<usize></span></code>
<code> <span class="keyword">xor</span> edx, edx <span class="comment"># i = 0</span></code>
<code>.LBB0_3: <span class="comment"># loop {</span></code>
<code> <span class="keyword">cmp</span> byte ptr [rdi + rdx], cl <span class="comment"># if haystack[i] == b</span></code>
<code> <span class="keyword">je</span> .LBB0_4 <span class="comment"># return Option::Some(i)</span></code>
<code> <span class="keyword">inc</span> rdx <span class="comment"># i++</span></code>
<code> <span class="keyword">cmp</span> rsi, rdx <span class="comment"># if haystack.len() != i</span></code>
<code> <span class="keyword">jne</span> .LBB0_3 <span class="comment"># continue</span></code>
<code> <span class="keyword">mov</span> rdx, rsi <span class="comment"># i = haystack.len()</span></code>
<code> <span class="keyword">ret</span> <span class="comment"># return None()</span></code>
<code>.LBB0_1: <span class="comment"># }</span></code>
<code> <span class="keyword">xor</span> edx, edx</code>
<code> <span class="keyword">xor</span> eax, eax</code>
<code> <span class="keyword">ret</span></code>
<code>.LBB0_4:</code>
<code> <span class="keyword">mov</span> eax, <span class="number">1</span></code>
<code> <span class="keyword">ret</span></code>
</pre>
<p>Can we improve the method’s performance?</p>
</section>
<section id="Implementing-find-without-early-returns">
<a class="heading-anchor" href="#Implementing-find-without-early-returns"><h3>Implementing <code>find</code> without early returns</h3>
</a><p>We can notice that for other iterator methods, such as <code>filter</code>, the compiler uses <code>SSE / AVX</code> instructions (if the target CPU supports them). So what prevents the compiler from using <code>SIMD</code> instructions for the <code>position</code> method? Within our team we came to the conclusion that the <code>position</code> implementation returns early, which makes it harder for <code>LLVM</code> to apply <code>SIMD</code> (although I have no proof of that).</p>
<p>We can assume that the compiler is able to vectorize a function if it performs a predictable number of operations (either constant or a simple function of input properties such as the slice length). How can we achieve that for the <code>position</code> function? Actually, there is a nice way to implement it without a <code>break</code> in the middle of the loop: we just need to process the slice in reverse order! In that case we can simply overwrite the result variable whenever we find a matching element, and the last write, which happens at the smallest matching index, wins:</p>
<pre class="language-rust"><code><span class="keyword">pub</span> <span class="keyword">fn</span> <span class="function">find_branchless</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> {</code>
<code> <span class="keyword">let</span> <span class="keyword">mut</span> <span class="identifier">position</span> = <span class="identifier">None</span>::<<span class="identifier">usize</span>>;</code>
<code> <span class="keyword">for</span> (<span class="identifier">i</span>, &<span class="identifier">b</span>) <span class="keyword">in</span> <span class="identifier">haystack</span>.<span class="function">iter</span>().<span class="function">enumerate</span>().<span class="function">rev</span>() {</code>
<code> <span class="keyword">if</span> <span class="identifier">b</span> == <span class="identifier">needle</span> {</code>
<code> <span class="identifier">position</span> = <span class="function">Some</span>(<span class="identifier">i</span>);</code>
<code> }</code>
<code> }</code>
<code> <span class="identifier">position</span></code>
<code>}</code>
</pre>
<p>Unfortunately, this doesn’t help – there are still no <code>SIMD</code> instructions in the output assembly. But wait, we can notice drastic changes in the <a href="https://godbolt.org/z/5Eh5rfaW3">output binary</a> – now it seems that the compiler unrolled our main loop and compares bytes in chunks of size 8:</p>
<pre class="language-asm"><code><span class="comment"># this is only part of the assembly; the full output is available at the godbolt link</span></code>
<code>.LBB0_11:</code>
<code> <span class="keyword">cmp</span> byte ptr [r8 + r11 - <span class="number">1</span>], dl</code>
<code> <span class="keyword">lea</span> r14, [rsi + r11 - <span class="number">1</span>]</code>
<code> <span class="keyword">lea</span> r15, [rsi + r11 - <span class="number">7</span>]</code>
<code> <span class="keyword">cmovne</span> r14, rcx</code>
<code> <span class="keyword">cmove</span> rax, rbx</code>
<code> <span class="keyword">cmp</span> byte ptr [r8 + r11 - <span class="number">2</span>], dl</code>
<code> <span class="keyword">lea</span> rcx, [rsi + r11 - <span class="number">2</span>]</code>
<code> <span class="keyword">cmovne</span> rcx, r14</code>
<code> <span class="keyword">cmove</span> rax, rbx</code>
<code> ...</code>
<code> <span class="keyword">cmp</span> byte ptr [r8 + r11 - <span class="number">8</span>], dl</code>
<code> <span class="keyword">lea</span> rcx, [rsi + r11 - <span class="number">8</span>]</code>
<code> <span class="keyword">cmovne</span> rcx, r15</code>
<code> <span class="keyword">cmove</span> rax, rbx</code>
</pre>
<p>That looks promising! Unrolling helps performance by itself, but we may also be on the right path to guiding the compiler towards successful vectorization!</p>
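<p>Before moving on, it is worth checking that the reverse loop really returns the <em>first</em> match rather than the last one. The sketch below repeats <code>find_branchless</code> from above and compares it with <code>Iterator::position</code>; the helper name <code>agrees_with_position</code> is hypothetical and exists only for this illustration.</p>

```rust
/// Reverse scan without early exit: every match overwrites `position`,
/// so the value left at the end belongs to the smallest matching index.
pub fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}

/// Hypothetical helper (not from the article): checks that the reverse scan
/// and the early-returning `Iterator::position` give the same answer.
pub fn agrees_with_position(haystack: &[u8], needle: u8) -> bool {
    let expected = haystack.iter().position(|&x| x == needle);
    find_branchless(haystack, needle) == expected
}
```

<p>For the input <code>[1, 2, 3, 2]</code> and needle <code>2</code>, the loop first stores index <code>3</code> and then overwrites it with index <code>1</code>, which is the same result that <code>position</code> gives.</p>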
</section>
<section id="Vectorization-by-any-means">
<a class="heading-anchor" href="#Vectorization-by-any-means"><h3>Vectorization by any means!</h3>
</a><p>At this point, I had no idea how else to make the compiler’s life easier, except for one last thing – we can make the slice length constant and hope that this finally activates the vectorization engine in <code>LLVM</code>. It turns out that this was enough! If we use the <code>[u8; 16]</code> or <code>[u8; 32]</code> types for the input argument, then <code>LLVM</code> <a href="https://godbolt.org/z/csYjj769s">will use</a> <code>128</code>-bit or <code>256</code>-bit <code>SSE</code> / <code>AVX</code> registers and the corresponding instructions!</p>
<pre class="language-asm"><code>example::find_branchless:</code>
<code> <span class="keyword">vmovd</span> xmm0, esi</code>
<code> <span class="keyword">vpbroadcastb</span> xmm0, xmm0</code>
<code> <span class="keyword">vpcmpeqb</span> xmm0, xmm0, xmmword ptr [rdi]</code>
<code> <span class="keyword">vpextrb</span> eax, xmm0, <span class="number">14</span></code>
<code> <span class="keyword">vpextrb</span> ecx, xmm0, <span class="number">13</span></code>
<code> <span class="keyword">vpextrb</span> edx, xmm0, <span class="number">10</span></code>
<code> <span class="keyword">and</span> eax, <span class="number">1</span></code>
<code> <span class="keyword">xor</span> rax, <span class="number">15</span></code>
<code> <span class="keyword">test</span> cl, <span class="number">1</span></code>
<code> <span class="keyword">mov</span> ecx, <span class="number">13</span></code>
<code> <span class="keyword">cmove</span> rcx, rax</code>
<code> <span class="keyword">vpextrb</span> eax, xmm0, <span class="number">12</span></code>
<code> <span class="comment"># ... there are a lot of instructions for determining actual position of matched byte ...</span></code>
<code> <span class="keyword">vpmovmskb</span> esi, xmm0</code>
<code> <span class="keyword">cmove</span> rdx, rcx</code>
<code> <span class="keyword">xor</span> eax, eax</code>
<code> <span class="keyword">test</span> esi, esi</code>
<code> <span class="keyword">setne</span> al</code>
<code> <span class="keyword">ret</span></code>
</pre>
<p>You can notice that the compiler generates truly branch-less code (literally, no jump instructions!). This can be surprising at first sight, but the compiler makes use of the <code>cmove</code> (“conditional move”) instruction, which moves a value between operands only if the flags register is in a specific state. Because it cannot be mispredicted, this instruction often performs much better than an ordinary <code>CMP</code> / <code>JE</code> pair on hard-to-predict data, and it allows implementing simple conditional scenarios like the one in the branch-less implementation of the <code>find</code> function.</p>
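<p>For reference, the fixed-size variant discussed in this section could look roughly like the sketch below. The name <code>find_branchless_fixed</code> and the use of a const generic parameter <code>N</code> are my assumptions; the godbolt link uses concrete array types. Whether the compiler actually vectorizes it depends on the compiler version and the target flags, so treat it as an experiment to paste into godbolt rather than a guarantee.</p>

```rust
/// Same reverse scan as before, but the length `N` is known at compile time,
/// so the loop has a constant trip count (for example N = 16 or N = 32).
/// The name and the const generic are illustrative assumptions.
pub fn find_branchless_fixed<const N: usize>(haystack: &[u8; N], needle: u8) -> Option<usize> {
    let mut position = None;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}
```

<p>Calling it with a <code>&amp;[u8; 16]</code> or <code>&amp;[u8; 32]</code> argument instantiates one specialized copy of the function per array length.</p>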
</section>
<section id="Vectorized-version-of-find">
<a class="heading-anchor" href="#Vectorized-version-of-find"><h3>Vectorized version of <code>find</code></h3>
</a><p>OK, it’s great that we finally forced <code>rustc</code> to use vectorization. But the current <code>find</code> implementation is barely usable because it works only for byte arrays of a fixed size! What can we do about that?</p>
<p>The next step is straightforward – we can split the input slice into chunks of bounded size and apply our branch-less implementation of <code>find</code> to each of them. <code>Rust</code> has a nice <a href="https://doc.rust-lang.org/std/primitive.slice.html#method.chunks"><code>chunks</code></a> method which does exactly what we want, so let’s try to use it:</p>
<pre class="language-rust"><code><span class="keyword">fn</span> <span class="function">find_branchless</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> { ... }</code>
<code><span class="keyword">pub</span> <span class="keyword">fn</span> <span class="function">find</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> {</code>
<code> <span class="identifier">haystack</span></code>
<code> .<span class="function">chunks</span>(<span class="number">32</span>)</code>
<code> .<span class="function">enumerate</span>()</code>
<code> .<span class="function">find_map</span>(|(<span class="identifier">i</span>, <span class="identifier">chunk</span>)| <span class="function">find_branchless</span>(<span class="identifier">chunk</span>, <span class="identifier">needle</span>).<span class="function">map</span>(|<span class="identifier">x</span>| <span class="number">32</span> * <span class="identifier">i</span> + <span class="identifier">x</span>) )</code>
<code>}</code>
</pre>
<p>Unfortunately, this doesn’t work – the compiler again produces boring assembly with only the unrolling optimization applied. But if we stop and think about it, this is actually expected! The chunking logic makes the size of every chunk unpredictable, because there is no guarantee about the exact size of the last chunk (and every chunk can be the last one!).</p>
<p>Luckily, the <code>Rust</code> developers thought about this and added the <a href="https://doc.rust-lang.org/std/primitive.slice.html#method.chunks_exact"><code>chunks_exact</code></a> method specifically for such cases! This method splits the slice into equally sized chunks and provides access to the tail, which may be smaller, through an additional method: <code>remainder</code>.</p>
<p>Switching to <code>chunks_exact</code> is the final step that makes our dream come true: a <a href="https://godbolt.org/z/Kja5WGjMf">vectorized <code>find</code> function</a> written in only safe <code>Rust</code>!</p>
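<p>To make the difference between the two methods concrete, here is a small sketch that reports how each of them splits a slice. The helper <code>chunk_layout</code> is hypothetical; it only calls <code>chunks</code>, <code>chunks_exact</code> and <code>remainder</code> from the standard library. Note that both chunking methods panic when the chunk size is zero.</p>

```rust
/// Hypothetical helper: returns
/// (length of the last chunk produced by `chunks`,
///  number of full chunks produced by `chunks_exact`,
///  length of the `chunks_exact` remainder).
/// Panics if `size` is 0, because `chunks` and `chunks_exact` do.
pub fn chunk_layout(data: &[u8], size: usize) -> (usize, usize, usize) {
    // With `chunks`, the last chunk may be shorter than `size`.
    let last_loose = data.chunks(size).last().map_or(0, |c| c.len());
    // With `chunks_exact`, every yielded chunk has exactly `size` elements,
    // and the leftover tail is exposed separately.
    let exact = data.chunks_exact(size);
    (last_loose, exact.len(), exact.remainder().len())
}
```

<p>For a 70-byte slice and a chunk size of 32, <code>chunks</code> yields chunks of 32, 32 and 6 bytes, while <code>chunks_exact</code> yields two chunks of exactly 32 bytes and leaves 6 bytes in the remainder.</p>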
<pre class="language-rust"><code><span class="comment">// bonus: refactoring of find_branchless function to make it more elegant!</span></code>
<code><span class="keyword">fn</span> <span class="function">find_branchless</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> {</code>
<code> <span class="identifier">haystack</span>.<span class="function">iter</span>().<span class="function">enumerate</span>()</code>
<code> .<span class="function">filter</span>(|&(_, &<span class="identifier">b</span>)| <span class="identifier">b</span> == <span class="identifier">needle</span>)</code>
<code> .<span class="function">rfold</span>(<span class="identifier">None</span>, |_, (<span class="identifier">i</span>, _)| <span class="function">Some</span>(<span class="identifier">i</span>))</code>
<code>}</code>
<code></code>
<code><span class="keyword">fn</span> <span class="function">find</span>(<span class="identifier">haystack</span>: &[<span class="identifier">u8</span>], <span class="identifier">needle</span>: <span class="identifier">u8</span>) -> <span class="identifier">Option</span><<span class="identifier">usize</span>> {</code>
<code> <span class="keyword">let</span> <span class="identifier">chunks</span> = <span class="identifier">haystack</span>.<span class="function">chunks_exact</span>(<span class="number">32</span>);</code>
<code> <span class="keyword">let</span> <span class="identifier">remainder</span> = <span class="identifier">chunks</span>.<span class="function">remainder</span>();</code>
<code> <span class="identifier">chunks</span>.<span class="function">enumerate</span>()</code>
<code> .<span class="function">find_map</span>(|(<span class="identifier">i</span>, <span class="identifier">chunk</span>)| <span class="function">find_branchless</span>(<span class="identifier">chunk</span>, <span class="identifier">needle</span>).<span class="function">map</span>(|<span class="identifier">x</span>| <span class="number">32</span> * <span class="identifier">i</span> + <span class="identifier">x</span>) )</code>
<code> .<span class="function">or</span>(<span class="function">find_branchless</span>(<span class="identifier">remainder</span>, <span class="identifier">needle</span>).<span class="function">map</span>(|<span class="identifier">x</span>| (<span class="identifier">haystack</span>.<span class="function">len</span>() & !<span class="number">0</span><span class="identifier">x1f</span>) + <span class="identifier">x</span>))</code>
<code>}</code>
</pre>
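<p>The expression <code>haystack.len() &amp; !0x1f</code> rounds the length down to a multiple of 32, which is exactly the offset at which the remainder starts. The sketch below is a slightly different spelling of the same idea, under my own assumptions: the chunk size is pulled into a constant named <code>CHUNK</code>, the tail offset is computed as <code>len - remainder.len()</code>, and <code>or_else</code> is used so that the remainder is scanned only when no full chunk matched. I have not checked how these changes affect the generated assembly.</p>

```rust
/// Chunk size in bytes; 32 matches the article, the constant name is an assumption.
const CHUNK: usize = 32;

/// Branch-less scan of one chunk: keep every match, let the smallest index win.
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack
        .iter()
        .enumerate()
        .filter(|&(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

/// Scans full chunks first, then the shorter tail.
pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(CHUNK);
    let remainder = chunks.remainder();
    // Equal to `haystack.len() & !0x1f` when CHUNK is 32.
    let tail_start = haystack.len() - remainder.len();
    chunks
        .enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| CHUNK * i + x))
        .or_else(|| find_branchless(remainder, needle).map(|x| tail_start + x))
}
```

<p>For a 100-byte haystack the tail starts at offset 96, so a match at index 1 of the remainder is reported as position 97.</p>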
</section>
<section id="Benchmarks">
<a class="heading-anchor" href="#Benchmarks"><h3>Benchmarks</h3>
</a><p>The full benchmark source code is available here: <a href="https://github.com/sivukhin/sivukhin.github.io/tree/master/rust-find-bench">./rust-find-bench</a></p>
<table>
<tr>
<th style="text-align: left;">method</th>
<th style="text-align: right;">time</th>
<th style="text-align: right;">speedup</th>
</tr>
<tr>
<td style="text-align: left;"><code>find_naive</code></td>
<td style="text-align: right;"><code>366.06 ns</code></td>
<td style="text-align: right;"><code>x1.0</code></td>
</tr>
<tr>
<td style="text-align: left;"><code>find_chunks</code></td>
<td style="text-align: right;"><code>414.06 ns</code></td>
<td style="text-align: right;"><code>x0.9</code></td>
</tr>
<tr>
<td style="text-align: left;"><code>find_chunks_exact</code></td>
<td style="text-align: right;"><code>133.53 ns</code></td>
<td style="text-align: right;"><code>x2.7</code></td>
</tr>
<tr>
<td style="text-align: left;"><code>find_chunks_exact_branchless</code></td>
<td style="text-align: right;"><code>40.48 ns</code></td>
<td style="text-align: right;"><strong><code>x9.0</code></strong></td>
</tr>
</table>
</section>
</div>
<hr/>
<div class="footer">
<p><a href="https://github.com/sivukhin/sivukhin.github.io">github link</a> (made with <a href="https://github.com/sivukhin/godjot">godjot</a> and <a href="https://github.com/sivukhin/gopeg">gopeg</a>)</p>
</div>
</div>
</body>
</html>