Skip to content

Commit

Permalink
refactor(rust): avoid nightly rust for case conversion (#11610)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Oct 9, 2023
1 parent 5a234a7 commit a16ea64
Showing 1 changed file with 41 additions and 139 deletions.
180 changes: 41 additions & 139 deletions crates/polars-ops/src/chunked_array/strings/case.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
#[cfg(feature = "nightly")]
use core::unicode::conversions;

use polars_core::prelude::Utf8Chunked;

// inlined from std
// Inlined from std.
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec<u8>) {
unsafe {
out.set_len(0);
out.reserve(b.len());
}
out.clear();
out.reserve(b.len());

const USIZE_SIZE: usize = std::mem::size_of::<usize>();
const MAGIC_UNROLL: usize = 2;
Expand All @@ -18,69 +13,43 @@ fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec<u8>) {
let mut i = 0;
unsafe {
while i + N <= b.len() {
// Safety: we have checks the sizes `b` and `out` to know that our
// SAFETY: we have checks the sizes `b` and `out`.
let in_chunk = b.get_unchecked(i..i + N);
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);

let mut bits = 0;
for j in 0..MAGIC_UNROLL {
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
// safety: in_chunk is valid bytes in the range
// Read the bytes 1 usize at a time (unaligned since we haven't checked the alignment).
// SAFETY: in_chunk is valid bytes in the range.
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
}
// if our chunks aren't ascii, then return only the prior bytes as init
// If our chunks aren't ascii, then return only the prior bytes as init.
if bits & NONASCII_MASK != 0 {
break;
}

// perform the case conversions on N bytes (gets heavily autovec'd)
// Perform the case conversions on N bytes (gets heavily autovec'd).
for j in 0..N {
// safety: in_chunk and out_chunk is valid bytes in the range
// SAFETY: in_chunk and out_chunk are valid bytes in the range.
let out = out_chunk.get_unchecked_mut(j);
out.write(convert(in_chunk.get_unchecked(j)));
}

// mark these bytes as initialised
// Mark these bytes as initialised.
i += N;
}
out.set_len(i);
}
}

#[cfg(not(feature = "nightly"))]
pub(super) fn to_lowercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked {
// this amortizes allocations and will not be freed
// so will have size of max(len)
let mut buf = Vec::new();

// this is one that will be set if we cannot convert ascii
// this length will change every iteration we must use this
let mut buf2 = Vec::new();
let f = |s: &'a str| {
convert_while_ascii(s.as_bytes(), u8::to_ascii_lowercase, &mut buf);
let slice = if buf.len() < s.len() {
buf2 = s.to_lowercase().into_bytes();
buf2.as_ref()
} else {
buf.as_ref()
};
// extend lifetime
// lifetime is bound to 'a
let slice = unsafe { std::str::from_utf8_unchecked(slice) };
unsafe { std::mem::transmute::<&str, &'a str>(slice) }
};
ca.apply_mut(f)
}

#[cfg(feature = "nightly")]
fn to_lowercase_helper(s: &str, buf: &mut Vec<u8>) {
convert_while_ascii(s.as_bytes(), u8::to_ascii_lowercase, buf);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
// SAFETY: we know this is a valid char boundary since
// out.len() is only progressed if ASCII bytes are found.
let rest = unsafe { s.get_unchecked(buf.len()..) };

// Safety: We have written only valid ASCII to our vec
// SAFETY: We have written only valid ASCII to our vec.
let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(buf)) };

for (i, c) in rest[..].char_indices() {
Expand All @@ -92,18 +61,7 @@ fn to_lowercase_helper(s: &str, buf: &mut Vec<u8>) {
// See https://github.com/rust-lang/rust/issues/26035
map_uppercase_sigma(rest, i, &mut s)
} else {
match conversions::to_lower(c) {
[a, '\0', _] => s.push(a),
[a, b, '\0'] => {
s.push(a);
s.push(b);
},
[a, b, c] => {
s.push(a);
s.push(b);
s.push(c);
},
}
s.extend(c.to_lowercase());
}
}

Expand All @@ -124,136 +82,80 @@ fn to_lowercase_helper(s: &str, buf: &mut Vec<u8>) {
None => false,
}
}
// put buf back for next iteration

// Put buf back for next iteration.
*buf = s.into_bytes();
}

// inlined from std
#[cfg(feature = "nightly")]
pub(super) fn to_lowercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked {
// amortize allocation
// Amortize allocation.
let mut buf = Vec::new();
let f = |s: &'a str| {
let f = |s: &'a str| -> &'a str {
to_lowercase_helper(s, &mut buf);

// extend lifetime
// lifetime is bound to 'a
// SAFETY: apply_mut will copy value from buf before next iteration.
let slice = unsafe { std::str::from_utf8_unchecked(&buf) };
unsafe { std::mem::transmute::<&str, &'a str>(slice) }
};
ca.apply_mut(f)
}

#[cfg(not(feature = "nightly"))]
// Inlined from std.
pub(super) fn to_uppercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked {
// this amortizes allocations and will not be freed
// so will have size of max(len)
// Amortize allocation.
let mut buf = Vec::new();

// this is one that will be set if we cannot convert ascii
// this length will change every iteration we must use this
let mut buf2 = Vec::new();
let f = |s: &'a str| {
let f = |s: &'a str| -> &'a str {
convert_while_ascii(s.as_bytes(), u8::to_ascii_uppercase, &mut buf);
let slice = if buf.len() < s.len() {
buf2 = s.to_uppercase().into_bytes();
buf2.as_ref()
} else {
buf.as_ref()
};
// extend lifetime
// lifetime is bound to 'a
let slice = unsafe { std::str::from_utf8_unchecked(slice) };
unsafe { std::mem::transmute::<&str, &'a str>(slice) }
};
ca.apply_mut(f)
}

#[inline]
#[cfg(feature = "nightly")]
fn push_char_to_upper(c: char, s: &mut String) {
match conversions::to_upper(c) {
[a, '\0', _] => s.push(a),
[a, b, '\0'] => {
s.push(a);
s.push(b);
},
[a, b, c] => {
s.push(a);
s.push(b);
s.push(c);
},
}
}

// inlined from std
#[cfg(feature = "nightly")]
pub(super) fn to_uppercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked {
// amortize allocation
let mut buf = Vec::new();
let f = |s: &'a str| {
convert_while_ascii(s.as_bytes(), u8::to_ascii_uppercase, &mut buf);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
// SAFETY: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found.
let rest = unsafe { s.get_unchecked(buf.len()..) };

// Safety: We have written only valid ASCII to our vec
// SAFETY: We have written only valid ASCII to our vec.
let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) };

for c in rest.chars() {
push_char_to_upper(c, &mut s);
s.extend(c.to_uppercase());
}

// put buf back for next iteration
// Put buf back for next iteration.
buf = s.into_bytes();

// extend lifetime
// lifetime is bound to 'a
// SAFETY: apply_mut will copy value from buf before next iteration.
let slice = unsafe { std::str::from_utf8_unchecked(&buf) };
unsafe { std::mem::transmute::<&str, &'a str>(slice) }
};
ca.apply_mut(f)
}

#[cfg(feature = "nightly")]
pub(super) fn to_titlecase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked {
// amortize allocation
// Amortize allocation.
let mut buf = Vec::new();

// temporary scratch
// we have a double copy as we first convert to lowercase
// and then copy to `buf`
// Temporary scratch space.
// We have a double copy as we first convert to lowercase and then copy to `buf`.
let mut scratch = Vec::new();
let f = |s: &'a str| {
unsafe {
buf.set_len(0);
}
// this helper sets scratch len to 0
let f = |s: &'a str| -> &'a str {
to_lowercase_helper(s, &mut scratch);

let mut next_is_upper = true;

let lowercased = unsafe { std::str::from_utf8_unchecked(&scratch) };

// SAFETY: the buffer is clear, empty string is valid UTF-8.
buf.clear();
let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) };

let mut next_is_upper = true;
for c in lowercased.chars() {
let is_whitespace = c.is_whitespace();
if is_whitespace || !next_is_upper {
next_is_upper = is_whitespace;
s.push(c);
if next_is_upper {
s.extend(c.to_uppercase());
} else {
next_is_upper = false;
push_char_to_upper(c, &mut s);
s.push(c);
}
next_is_upper = c.is_whitespace();
}

// put buf back for next iteration
// Put buf back for next iteration.
buf = s.into_bytes();

// extend lifetime
// lifetime is bound to 'a
// SAFETY: apply_mut will copy value from buf before next iteration.
let slice = unsafe { std::str::from_utf8_unchecked(&buf) };
unsafe { std::mem::transmute::<&str, &'a str>(slice) }
};
Expand Down

0 comments on commit a16ea64

Please sign in to comment.