Improve autovectorization of to_lowercase / to_uppercase functions
Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and also process the full ascii prefix of the string. The generic case conversion logic will only be invoked starting from the first non-ascii character. The runtime on microbenchmarks with ascii-only inputs improves between 1.5x for short and 4x for long inputs on x86_64 and aarch64. The new implementation also encapsulates all unsafe inside the `convert_while_ascii` function. Fixes #123712
This commit is contained in:
parent
702987f75b
commit
e393f56d37
@ -347,3 +347,5 @@ mod $name {
|
|||||||
|
|
||||||
make_test!(split_space_str, s, s.split(" ").count());
|
make_test!(split_space_str, s, s.split(" ").count());
|
||||||
make_test!(split_ad_str, s, s.split("ad").count());
|
make_test!(split_ad_str, s, s.split("ad").count());
|
||||||
|
|
||||||
|
make_test!(to_lowercase, s, s.to_lowercase());
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
use core::borrow::{Borrow, BorrowMut};
|
use core::borrow::{Borrow, BorrowMut};
|
||||||
use core::iter::FusedIterator;
|
use core::iter::FusedIterator;
|
||||||
|
use core::mem::MaybeUninit;
|
||||||
#[stable(feature = "encode_utf16", since = "1.8.0")]
|
#[stable(feature = "encode_utf16", since = "1.8.0")]
|
||||||
pub use core::str::EncodeUtf16;
|
pub use core::str::EncodeUtf16;
|
||||||
#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]
|
#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]
|
||||||
@ -365,14 +366,9 @@ pub fn replacen<P: Pattern>(&self, pat: P, to: &str, count: usize) -> String {
|
|||||||
without modifying the original"]
|
without modifying the original"]
|
||||||
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
||||||
pub fn to_lowercase(&self) -> String {
|
pub fn to_lowercase(&self) -> String {
|
||||||
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
|
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
|
||||||
|
|
||||||
// Safety: we know this is a valid char boundary since
|
let prefix_len = s.len();
|
||||||
// out.len() is only progressed if ascii bytes are found
|
|
||||||
let rest = unsafe { self.get_unchecked(out.len()..) };
|
|
||||||
|
|
||||||
// Safety: We have written only valid ASCII to our vec
|
|
||||||
let mut s = unsafe { String::from_utf8_unchecked(out) };
|
|
||||||
|
|
||||||
for (i, c) in rest.char_indices() {
|
for (i, c) in rest.char_indices() {
|
||||||
if c == 'Σ' {
|
if c == 'Σ' {
|
||||||
@ -381,8 +377,7 @@ pub fn to_lowercase(&self) -> String {
|
|||||||
// in `SpecialCasing.txt`,
|
// in `SpecialCasing.txt`,
|
||||||
// so hard-code it rather than have a generic "condition" mechanism.
|
// so hard-code it rather than have a generic "condition" mechanism.
|
||||||
// See https://github.com/rust-lang/rust/issues/26035
|
// See https://github.com/rust-lang/rust/issues/26035
|
||||||
let out_len = self.len() - rest.len();
|
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
|
||||||
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
|
|
||||||
s.push(sigma_lowercase);
|
s.push(sigma_lowercase);
|
||||||
} else {
|
} else {
|
||||||
match conversions::to_lower(c) {
|
match conversions::to_lower(c) {
|
||||||
@ -458,14 +453,7 @@ fn case_ignorable_then_cased<I: Iterator<Item = char>>(iter: I) -> bool {
|
|||||||
without modifying the original"]
|
without modifying the original"]
|
||||||
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
||||||
pub fn to_uppercase(&self) -> String {
|
pub fn to_uppercase(&self) -> String {
|
||||||
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
|
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
|
||||||
|
|
||||||
// Safety: we know this is a valid char boundary since
|
|
||||||
// out.len() is only progressed if ascii bytes are found
|
|
||||||
let rest = unsafe { self.get_unchecked(out.len()..) };
|
|
||||||
|
|
||||||
// Safety: We have written only valid ASCII to our vec
|
|
||||||
let mut s = unsafe { String::from_utf8_unchecked(out) };
|
|
||||||
|
|
||||||
for c in rest.chars() {
|
for c in rest.chars() {
|
||||||
match conversions::to_upper(c) {
|
match conversions::to_upper(c) {
|
||||||
@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
|
|||||||
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
|
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the bytes while the bytes are still ascii.
|
/// Converts leading ascii bytes in `s` by calling the `convert` function.
|
||||||
|
///
|
||||||
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
|
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
|
||||||
/// Returns a vec with the converted bytes.
|
///
|
||||||
|
/// Returns a tuple of the converted prefix and the remainder starting from
|
||||||
|
/// the first non-ascii character.
|
||||||
|
///
|
||||||
|
/// This function is only public so that it can be verified in a codegen test,
|
||||||
|
/// see `issue-123712-str-to-lower-autovectorization.rs`.
|
||||||
|
#[unstable(feature = "str_internals", issue = "none")]
|
||||||
|
#[doc(hidden)]
|
||||||
#[inline]
|
#[inline]
|
||||||
#[cfg(not(test))]
|
#[cfg(not(test))]
|
||||||
#[cfg(not(no_global_oom_handling))]
|
#[cfg(not(no_global_oom_handling))]
|
||||||
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
|
pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
|
||||||
let mut out = Vec::with_capacity(b.len());
|
// Process the input in chunks of 16 bytes to enable auto-vectorization.
|
||||||
|
// Previously the chunk size depended on the size of `usize`,
|
||||||
|
// but on 32-bit platforms with sse or neon a 16-byte chunk is also the better choice.
|
||||||
|
// The only downside on other platforms would be a bit more loop-unrolling.
|
||||||
|
const N: usize = 16;
|
||||||
|
|
||||||
const USIZE_SIZE: usize = mem::size_of::<usize>();
|
let mut slice = s.as_bytes();
|
||||||
const MAGIC_UNROLL: usize = 2;
|
let mut out = Vec::with_capacity(slice.len());
|
||||||
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
|
let mut out_slice = out.spare_capacity_mut();
|
||||||
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
|
|
||||||
|
|
||||||
let mut i = 0;
|
let mut ascii_prefix_len = 0_usize;
|
||||||
unsafe {
|
let mut is_ascii = [false; N];
|
||||||
while i + N <= b.len() {
|
|
||||||
// Safety: we have checked the sizes of `b` and `out` to know that our chunks are in bounds
|
|
||||||
let in_chunk = b.get_unchecked(i..i + N);
|
|
||||||
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
|
|
||||||
|
|
||||||
let mut bits = 0;
|
while slice.len() >= N {
|
||||||
for j in 0..MAGIC_UNROLL {
|
// SAFETY: checked in loop condition
|
||||||
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
|
let chunk = unsafe { slice.get_unchecked(..N) };
|
||||||
// safety: in_chunk is valid bytes in the range
|
// SAFETY: out_slice has at least the same length as the input slice and gets sliced with the same offsets
|
||||||
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
|
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
|
||||||
|
|
||||||
|
for j in 0..N {
|
||||||
|
is_ascii[j] = chunk[j] <= 127;
|
||||||
}
|
}
|
||||||
// if our chunks aren't ascii, then return only the prior bytes as init
|
|
||||||
if bits & NONASCII_MASK != 0 {
|
// Auto-vectorization for this check is a bit fragile; summing and comparing against the chunk
|
||||||
|
// size gives the best result, specifically a pmovmsk instruction on x86.
|
||||||
|
// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not
|
||||||
|
// recognize other similar idioms.
|
||||||
|
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// perform the case conversions on N bytes (gets heavily autovec'd)
|
|
||||||
for j in 0..N {
|
for j in 0..N {
|
||||||
// safety: in_chunk and out_chunk is valid bytes in the range
|
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
|
||||||
let out = out_chunk.get_unchecked_mut(j);
|
|
||||||
out.write(convert(in_chunk.get_unchecked(j)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mark these bytes as initialised
|
ascii_prefix_len += N;
|
||||||
i += N;
|
slice = unsafe { slice.get_unchecked(N..) };
|
||||||
}
|
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
|
||||||
out.set_len(i);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
out
|
// handle the remainder as individual bytes
|
||||||
|
while slice.len() > 0 {
|
||||||
|
let byte = slice[0];
|
||||||
|
if byte > 127 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// SAFETY: out_slice has at least the same length as the input slice
|
||||||
|
unsafe {
|
||||||
|
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
|
||||||
|
}
|
||||||
|
ascii_prefix_len += 1;
|
||||||
|
slice = unsafe { slice.get_unchecked(1..) };
|
||||||
|
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// SAFETY: ascii_prefix_len bytes have been initialized above
|
||||||
|
out.set_len(ascii_prefix_len);
|
||||||
|
|
||||||
|
// SAFETY: We have written only valid ascii to the output vec
|
||||||
|
let ascii_string = String::from_utf8_unchecked(out);
|
||||||
|
|
||||||
|
// SAFETY: we know this is a valid char boundary
|
||||||
|
// since we only skipped over leading ascii bytes
|
||||||
|
let rest = core::str::from_utf8_unchecked(slice);
|
||||||
|
|
||||||
|
(ascii_string, rest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1854,7 +1854,10 @@ fn to_lowercase() {
|
|||||||
assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");
|
assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");
|
||||||
|
|
||||||
// https://github.com/rust-lang/rust/issues/124714
|
// https://github.com/rust-lang/rust/issues/124714
|
||||||
|
// input lengths around the boundary of the chunk size used by the ascii prefix optimization
|
||||||
|
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
|
||||||
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
|
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
|
||||||
|
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");
|
||||||
|
|
||||||
// a really long string that has its lowercase form
|
// a really long string that has its lowercase form
|
||||||
// even longer. this tests that implementations don't assume
|
// even longer. this tests that implementations don't assume
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
//@ only-x86_64
|
||||||
|
//@ compile-flags: -C opt-level=3
|
||||||
|
#![crate_type = "lib"]
|
||||||
|
#![no_std]
|
||||||
|
#![feature(str_internals)]
|
||||||
|
|
||||||
|
extern crate alloc;
|
||||||
|
|
||||||
|
/// Ensure that the ascii-prefix loop for `str::to_lowercase` and `str::to_uppercase` uses vector
|
||||||
|
/// instructions.
|
||||||
|
///
|
||||||
|
/// The llvm ir should be the same for all targets that support some form of simd. Only targets
|
||||||
|
/// without any simd instructions would see scalarized ir.
|
||||||
|
/// Unfortunately, there is no `only-simd` directive to run this test only on such platforms,
|
||||||
|
/// and using test revisions would still require the core libraries for all platforms.
|
||||||
|
// CHECK-LABEL: @lower_while_ascii
|
||||||
|
// CHECK: [[A:%[0-9]]] = load <16 x i8>
|
||||||
|
// CHECK-NEXT: [[B:%[0-9]]] = icmp slt <16 x i8> [[A]], zeroinitializer
|
||||||
|
// CHECK-NEXT: [[C:%[0-9]]] = bitcast <16 x i1> [[B]] to i16
|
||||||
|
#[no_mangle]
|
||||||
|
pub fn lower_while_ascii(s: &str) -> (alloc::string::String, &str) {
|
||||||
|
alloc::str::convert_while_ascii(s, u8::to_ascii_lowercase)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user