From cfcce8e684c5e1bb2f9a74e55debf801ef27706f Mon Sep 17 00:00:00 2001 From: The 8472 Date: Sun, 23 Oct 2022 19:19:37 +0200 Subject: [PATCH] specialize iter::ArrayChunks::fold for TrustedRandomAccess iters This is fairly safe use of TRA since it consumes the iterator so no struct in an unsafe state will be left exposed to user code --- .../core/src/iter/adapters/array_chunks.rs | 89 ++++++++++++++++++- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/library/core/src/iter/adapters/array_chunks.rs b/library/core/src/iter/adapters/array_chunks.rs index d4fb886101f..3f0fad4ed33 100644 --- a/library/core/src/iter/adapters/array_chunks.rs +++ b/library/core/src/iter/adapters/array_chunks.rs @@ -1,6 +1,8 @@ use crate::array; -use crate::iter::{ByRefSized, FusedIterator, Iterator}; -use crate::ops::{ControlFlow, Try}; +use crate::const_closure::ConstFnMutClosure; +use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce}; +use crate::mem::{self, MaybeUninit}; +use crate::ops::{ControlFlow, NeverShortCircuit, Try}; /// An iterator over `N` elements of the iterator at a time. /// @@ -82,7 +84,13 @@ fn try_fold(&mut self, init: B, mut f: F) -> R } } - impl_fold_via_try_fold! { fold -> try_fold } + fn fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + ::fold(self, init, f) + } } #[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")] @@ -168,3 +176,78 @@ fn is_empty(&self) -> bool { self.iter.len() < N } } + +trait SpecFold: Iterator { + fn fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B; +} + +impl SpecFold for ArrayChunks +where + I: Iterator, +{ + #[inline] + default fn fold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp); + self.try_fold(init, fold).0 + } +} + +impl SpecFold for ArrayChunks +where + I: Iterator + TrustedRandomAccessNoCoerce, +{ + #[inline] + fn fold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + if self.remainder.is_some() { + return init; + } + + let mut accum = init; + let inner_len = self.iter.size(); + let mut i = 0; + // Use a while loop because (0..len).step_by(N) doesn't optimize well. + while inner_len - i >= N { + let mut chunk = MaybeUninit::uninit_array(); + let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 }; + for j in 0..N { + // SAFETY: The method consumes the iterator and the loop condition ensures that + // all accesses are in bounds and only happen once. + guard.array_mut[j].write(unsafe { self.iter.__iterator_get_unchecked(i + j) }); + guard.initialized = j + 1; + } + mem::forget(guard); + // SAFETY: The loop above initialized all elements + let chunk = unsafe { MaybeUninit::array_assume_init(chunk) }; + accum = f(accum, chunk); + i += N; + } + + let remainder = inner_len % N; + + let mut tail = MaybeUninit::uninit_array(); + let mut guard = array::Guard { array_mut: &mut tail, initialized: 0 }; + for i in 0..remainder { + // SAFETY: the remainder was not visited by the previous loop, so we're still only + // accessing each element once + let val = unsafe { self.iter.__iterator_get_unchecked(inner_len - remainder + i) }; + guard.array_mut[i].write(val); + guard.initialized = i + 1; + } + mem::forget(guard); + // SAFETY: the loop above initialized elements up to the `remainder` index + self.remainder = Some(unsafe { array::IntoIter::new_unchecked(tail, 0..remainder) }); + + accum + } +}