From 23637e20cdf3f7b8e01b42dbaf25357e5d3c31ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Wed, 6 Oct 2021 22:26:36 +0200 Subject: [PATCH] libcore: assume the input of `next_code_point` and `next_code_point_reverse` is UTF-8-like The functions are now `unsafe` and they use `Option::unwrap_unchecked` instead of `unwrap_or_0` `unwrap_or_0` was added in 42357d772b8a3a1ce4395deeac0a5cf1f66e951d. I guess `unwrap_unchecked` was not available back then. Given this example: ```rust pub fn first_char(s: &str) -> Option { s.chars().next() } ``` Previously, the following assembly was produced: ```asm _ZN7example10first_char17ha056ddea6bafad1cE: .cfi_startproc test rsi, rsi je .LBB0_1 movzx edx, byte ptr [rdi] test dl, dl js .LBB0_3 mov eax, edx ret .LBB0_1: mov eax, 1114112 ret .LBB0_3: lea r8, [rdi + rsi] xor eax, eax mov r9, r8 cmp rsi, 1 je .LBB0_5 movzx eax, byte ptr [rdi + 1] add rdi, 2 and eax, 63 mov r9, rdi .LBB0_5: mov ecx, edx and ecx, 31 cmp dl, -33 jbe .LBB0_6 cmp r9, r8 je .LBB0_9 movzx esi, byte ptr [r9] add r9, 1 and esi, 63 shl eax, 6 or eax, esi cmp dl, -16 jb .LBB0_12 .LBB0_13: cmp r9, r8 je .LBB0_14 movzx edx, byte ptr [r9] and edx, 63 jmp .LBB0_16 .LBB0_6: shl ecx, 6 or eax, ecx ret .LBB0_9: xor esi, esi mov r9, r8 shl eax, 6 or eax, esi cmp dl, -16 jae .LBB0_13 .LBB0_12: shl ecx, 12 or eax, ecx ret .LBB0_14: xor edx, edx .LBB0_16: and ecx, 7 shl ecx, 18 shl eax, 6 or eax, ecx or eax, edx ret ``` After this change, the assembly is reduced to: ```asm _ZN7example10first_char17h4318683472f884ccE: .cfi_startproc test rsi, rsi je .LBB0_1 movzx ecx, byte ptr [rdi] test cl, cl js .LBB0_3 mov eax, ecx ret .LBB0_1: mov eax, 1114112 ret .LBB0_3: mov eax, ecx and eax, 31 movzx esi, byte ptr [rdi + 1] and esi, 63 cmp cl, -33 jbe .LBB0_4 movzx edx, byte ptr [rdi + 2] shl esi, 6 and edx, 63 or edx, esi cmp cl, -16 jb .LBB0_7 movzx ecx, byte ptr [rdi + 3] and eax, 7 shl eax, 18 shl edx, 6 and ecx, 63 or ecx, edx or eax, ecx ret .LBB0_4: shl eax, 6 or eax, esi ret .LBB0_7: shl eax, 12 or eax, edx ret ``` --- library/core/src/str/iter.rs | 14 ++++----- library/core/src/str/validations.rs | 44 ++++++++++++++++++----------- library/std/src/sys_common/wtf8.rs | 3 +- 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/library/core/src/str/iter.rs b/library/core/src/str/iter.rs index 94a534c6e79..48410446716 100644 --- a/library/core/src/str/iter.rs +++ b/library/core/src/str/iter.rs @@ -39,10 +39,9 @@ impl<'a> Iterator for Chars<'a> { #[inline] fn next(&mut self) -> Option { - next_code_point(&mut self.iter).map(|ch| { - // SAFETY: `str` invariant says `ch` is a valid Unicode Scalar Value. - unsafe { char::from_u32_unchecked(ch) } - }) + // SAFETY: `str` invariant says `self.iter` is a valid UTF-8 string and + // the resulting `ch` is a valid Unicode Scalar Value. + unsafe { next_code_point(&mut self.iter).map(|ch| char::from_u32_unchecked(ch)) } } #[inline] @@ -81,10 +80,9 @@ impl fmt::Debug for Chars<'_> { impl<'a> DoubleEndedIterator for Chars<'a> { #[inline] fn next_back(&mut self) -> Option { - next_code_point_reverse(&mut self.iter).map(|ch| { - // SAFETY: `str` invariant says `ch` is a valid Unicode Scalar Value. - unsafe { char::from_u32_unchecked(ch) } - }) + // SAFETY: `str` invariant says `self.iter` is a valid UTF-8 string and + // the resulting `ch` is a valid Unicode Scalar Value. + unsafe { next_code_point_reverse(&mut self.iter).map(|ch| char::from_u32_unchecked(ch)) } } } diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index e362d5c05c1..be9c41a491b 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -25,19 +25,15 @@ pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool { (byte as i8) < -64 } -#[inline] -const fn unwrap_or_0(opt: Option<&u8>) -> u8 { - match opt { - Some(&byte) => byte, - None => 0, - } -} - /// Reads the next code point out of a byte iterator (assuming a /// UTF-8-like encoding). +/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string #[unstable(feature = "str_internals", issue = "none")] #[inline] -pub fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option { +pub unsafe fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option { // Decode UTF-8 let x = *bytes.next()?; if x < 128 { @@ -48,18 +44,24 @@ pub fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option< // Decode from a byte combination out of: [[[x y] z] w] // NOTE: Performance is sensitive to the exact formulation here let init = utf8_first_byte(x, 2); - let y = unwrap_or_0(bytes.next()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let y = unsafe { *bytes.next().unwrap_unchecked() }; let mut ch = utf8_acc_cont_byte(init, y); if x >= 0xE0 { // [[x y z] w] case // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid - let z = unwrap_or_0(bytes.next()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let z = unsafe { *bytes.next().unwrap_unchecked() }; let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z); ch = init << 12 | y_z; if x >= 0xF0 { // [x y z w] case // use only the lower 3 bits of `init` - let w = unwrap_or_0(bytes.next()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let w = unsafe { *bytes.next().unwrap_unchecked() }; ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w); } } @@ -69,8 +71,12 @@ pub fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option< /// Reads the last code point out of a byte iterator (assuming a /// UTF-8-like encoding). +/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string #[inline] -pub(super) fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option +pub(super) unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option where I: DoubleEndedIterator, { @@ -83,13 +89,19 @@ where // Multibyte case follows // Decode from a byte combination out of: [x [y [z w]]] let mut ch; - let z = unwrap_or_0(bytes.next_back()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let z = unsafe { *bytes.next_back().unwrap_unchecked() }; ch = utf8_first_byte(z, 2); if utf8_is_cont_byte(z) { - let y = unwrap_or_0(bytes.next_back()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let y = unsafe { *bytes.next_back().unwrap_unchecked() }; ch = utf8_first_byte(y, 3); if utf8_is_cont_byte(y) { - let x = unwrap_or_0(bytes.next_back()); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let x = unsafe { *bytes.next_back().unwrap_unchecked() }; ch = utf8_first_byte(x, 4); ch = utf8_acc_cont_byte(ch, y); } diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs index 0629859bd9d..6e29bc61454 100644 --- a/library/std/src/sys_common/wtf8.rs +++ b/library/std/src/sys_common/wtf8.rs @@ -809,7 +809,8 @@ impl<'a> Iterator for Wtf8CodePoints<'a> { #[inline] fn next(&mut self) -> Option { - next_code_point(&mut self.bytes).map(|c| CodePoint { value: c }) + // SAFETY: `self.bytes` has been created from a WTF-8 string + unsafe { next_code_point(&mut self.bytes).map(|c| CodePoint { value: c }) } } #[inline]