Skip to content

Commit a505dde

Browse files
authored
Port rust-lang/rust#78857 - Improve BinaryHeap performance (#28)
* Remove useless branches from sift_down_range loop * Remove branches from sift_down_to_bottom loop * Remove useless bound checks from into_sorted_vec
1 parent ffee63f commit a505dde

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

src/binary_heap.rs

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,14 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
938938
let mut end = self.len();
939939
while end > 1 {
940940
end -= 1;
941-
self.data.swap(0, end);
941+
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included),
942+
// so it's always a valid index to access.
943+
// It is safe to access index 0 (i.e. `ptr`), because
944+
// 1 <= end < self.len(), which means self.len() >= 2.
945+
unsafe {
946+
let ptr = self.data.as_mut_ptr();
947+
ptr::swap(ptr, ptr.add(end));
948+
}
942949
self.sift_down_range(0, end);
943950
}
944951
self.into_vec()
@@ -975,23 +982,24 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
975982
unsafe {
976983
let mut hole = Hole::new(&mut self.data, pos);
977984
let mut child = 2 * pos + 1;
978-
while child < end {
979-
let right = child + 1;
985+
while child < end - 1 {
980986
// compare with the greater of the two children
981-
// if right < end && !(hole.get(child) > hole.get(right)) {
982-
if right < end
983-
&& self.cmp.compare(hole.get(child), hole.get(right)) != Ordering::Greater
984-
{
985-
child = right;
986-
}
987+
// if !(hole.get(child) > hole.get(child + 1)) { child += 1 }
988+
child += (self.cmp.compare(hole.get(child), hole.get(child + 1))
989+
!= Ordering::Greater) as usize;
987990
// if we are already in order, stop.
988991
// if hole.element() >= hole.get(child) {
989992
if self.cmp.compare(hole.element(), hole.get(child)) != Ordering::Less {
990-
break;
993+
return;
991994
}
992995
hole.move_to(child);
993996
child = 2 * hole.pos() + 1;
994997
}
998+
if child == end - 1
999+
&& self.cmp.compare(hole.element(), hole.get(child)) == Ordering::Less
1000+
{
1001+
hole.move_to(child);
1002+
}
9951003
}
9961004
}
9971005

@@ -1011,18 +1019,18 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
10111019
unsafe {
10121020
let mut hole = Hole::new(&mut self.data, pos);
10131021
let mut child = 2 * pos + 1;
1014-
while child < end {
1022+
while child < end - 1 {
10151023
let right = child + 1;
10161024
// compare with the greater of the two children
1017-
// if right < end && !(hole.get(child) > hole.get(right)) {
1018-
if right < end
1019-
&& self.cmp.compare(hole.get(child), hole.get(right)) != Ordering::Greater
1020-
{
1021-
child = right;
1022-
}
1025+
// if !(hole.get(child) > hole.get(right)) { child += 1 }
1026+
child += (self.cmp.compare(hole.get(child), hole.get(right)) != Ordering::Greater)
1027+
as usize;
10231028
hole.move_to(child);
10241029
child = 2 * hole.pos() + 1;
10251030
}
1031+
if child == end - 1 {
1032+
hole.move_to(child);
1033+
}
10261034
pos = hole.pos;
10271035
}
10281036
self.sift_up(start, pos);

0 commit comments

Comments
 (0)