From 2c6e67105e4967f8b37ebe9ed92880c6773eb29e Mon Sep 17 00:00:00 2001 From: The8472 Date: Mon, 12 Jul 2021 21:40:38 +0200 Subject: [PATCH] implement advance_(back_)_by on more iterators --- library/alloc/src/lib.rs | 1 + library/alloc/src/vec/into_iter.rs | 45 +++++++++++ library/alloc/tests/lib.rs | 1 + library/alloc/tests/vec.rs | 18 +++++ library/core/src/iter/adapters/copied.rs | 10 +++ library/core/src/iter/adapters/cycle.rs | 21 ++++++ library/core/src/iter/adapters/enumerate.rs | 22 ++++++ library/core/src/iter/adapters/flatten.rs | 69 +++++++++++++++++ library/core/src/iter/adapters/skip.rs | 21 ++++++ library/core/src/iter/adapters/take.rs | 34 +++++++++ library/core/src/iter/range.rs | 79 ++++++++++++++++++++ library/core/src/iter/traits/double_ended.rs | 8 +- library/core/src/iter/traits/iterator.rs | 8 +- library/core/tests/iter/adapters/flatten.rs | 17 +++++ library/core/tests/iter/range.rs | 23 ++++++ 15 files changed, 375 insertions(+), 2 deletions(-) diff --git a/library/alloc/src/lib.rs b/library/alloc/src/lib.rs index 2aed9d03bc0..ca41ce975e4 100644 --- a/library/alloc/src/lib.rs +++ b/library/alloc/src/lib.rs @@ -111,6 +111,7 @@ // that the feature-gate isn't enabled. Ideally, it wouldn't check for the feature gate for docs // from other crates, but since this can only appear for lang items, it doesn't seem worth fixing. #![feature(intra_doc_pointers)] +#![feature(iter_advance_by)] #![feature(iter_zip)] #![feature(lang_items)] #![feature(layout_for_ptr)] diff --git a/library/alloc/src/vec/into_iter.rs b/library/alloc/src/vec/into_iter.rs index 4cb0a4b10bd..eae9ad076dc 100644 --- a/library/alloc/src/vec/into_iter.rs +++ b/library/alloc/src/vec/into_iter.rs @@ -161,6 +161,28 @@ fn size_hint(&self) -> (usize, Option) { (exact, Some(exact)) } + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + let step_size = self.len().min(n); + if mem::size_of::() == 0 { + // SAFETY: due to unchecked casts of unsigned amounts to signed offsets the wraparound + // effectively results in unsigned pointers representing positions 0..usize::MAX, + // which is valid for ZSTs. + self.ptr = unsafe { arith_offset(self.ptr as *const i8, step_size as isize) as *mut T } + } else { + let to_drop = ptr::slice_from_raw_parts_mut(self.ptr as *mut T, step_size); + // SAFETY: the min() above ensures that step_size is in bounds + unsafe { + self.ptr = self.ptr.add(step_size); + ptr::drop_in_place(to_drop); + } + } + if step_size < n { + return Err(step_size); + } + Ok(()) + } + #[inline] fn count(self) -> usize { self.len() @@ -203,6 +225,29 @@ fn next_back(&mut self) -> Option { Some(unsafe { ptr::read(self.end) }) } } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let step_size = self.len().min(n); + if mem::size_of::() == 0 { + // SAFETY: same as for advance_by() + self.end = unsafe { + arith_offset(self.end as *const i8, step_size.wrapping_neg() as isize) as *mut T + } + } else { + // SAFETY: same as for advance_by() + self.end = unsafe { self.end.offset(step_size.wrapping_neg() as isize) }; + let to_drop = ptr::slice_from_raw_parts_mut(self.end as *mut T, step_size); + // SAFETY: same as for advance_by() + unsafe { + ptr::drop_in_place(to_drop); + } + } + if step_size < n { + return Err(step_size); + } + Ok(()) + } } #[stable(feature = "rust1", since = "1.0.0")] diff --git a/library/alloc/tests/lib.rs b/library/alloc/tests/lib.rs index cae4dae708e..c6159539b48 100644 --- a/library/alloc/tests/lib.rs +++ b/library/alloc/tests/lib.rs @@ -18,6 +18,7 @@ #![feature(binary_heap_retain)] #![feature(binary_heap_as_slice)] #![feature(inplace_iteration)] +#![feature(iter_advance_by)] #![feature(slice_group_by)] #![feature(slice_partition_dedup)] #![feature(vec_spare_capacity)] diff --git a/library/alloc/tests/vec.rs b/library/alloc/tests/vec.rs index c2df50b48f5..00a878c0794 100644 --- a/library/alloc/tests/vec.rs +++ b/library/alloc/tests/vec.rs @@ -970,6 +970,24 @@ fn drop(&mut self) { assert_eq!(unsafe { DROPS }, 3); } +#[test] +fn test_into_iter_advance_by() { + let mut i = vec![1, 2, 3, 4, 5].into_iter(); + i.advance_by(0).unwrap(); + i.advance_back_by(0).unwrap(); + assert_eq!(i.as_slice(), [1, 2, 3, 4, 5]); + + i.advance_by(1).unwrap(); + i.advance_back_by(1).unwrap(); + assert_eq!(i.as_slice(), [2, 3, 4]); + + assert_eq!(i.advance_back_by(usize::MAX), Err(3)); + + assert_eq!(i.advance_by(usize::MAX), Err(0)); + + assert_eq!(i.len(), 0); +} + #[test] fn test_from_iter_specialization() { let src: Vec = vec![0usize; 1]; diff --git a/library/core/src/iter/adapters/copied.rs b/library/core/src/iter/adapters/copied.rs index 3d3c8da678b..e5f2886dcaf 100644 --- a/library/core/src/iter/adapters/copied.rs +++ b/library/core/src/iter/adapters/copied.rs @@ -76,6 +76,11 @@ fn count(self) -> usize { self.it.count() } + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + self.it.advance_by(n) + } + #[doc(hidden)] unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> T where @@ -112,6 +117,11 @@ fn rfold(self, init: Acc, f: F) -> Acc { self.it.rfold(init, copy_fold(f)) } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + self.it.advance_back_by(n) + } } #[stable(feature = "iter_copied", since = "1.36.0")] diff --git a/library/core/src/iter/adapters/cycle.rs b/library/core/src/iter/adapters/cycle.rs index 815e708f9ec..02b5939072e 100644 --- a/library/core/src/iter/adapters/cycle.rs +++ b/library/core/src/iter/adapters/cycle.rs @@ -79,6 +79,27 @@ fn try_fold(&mut self, mut acc: Acc, mut f: F) -> R } } + #[inline] + #[rustc_inherit_overflow_checks] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + let mut rem = n; + match self.iter.advance_by(rem) { + ret @ Ok(_) => return ret, + Err(advanced) => rem -= advanced, + } + + while rem > 0 { + self.iter = self.orig.clone(); + match self.iter.advance_by(rem) { + ret @ Ok(_) => return ret, + Err(0) => return Err(n - rem), + Err(advanced) => rem -= advanced, + } + } + + Ok(()) + } + // No `fold` override, because `fold` doesn't make much sense for `Cycle`, // and we can't do anything better than the default. } diff --git a/library/core/src/iter/adapters/enumerate.rs b/library/core/src/iter/adapters/enumerate.rs index 3478a0cd408..c877b45095a 100644 --- a/library/core/src/iter/adapters/enumerate.rs +++ b/library/core/src/iter/adapters/enumerate.rs @@ -112,6 +112,21 @@ fn enumerate( self.iter.fold(init, enumerate(self.count, fold)) } + #[inline] + #[rustc_inherit_overflow_checks] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + match self.iter.advance_by(n) { + ret @ Ok(_) => { + self.count += n; + ret + } + ret @ Err(advanced) => { + self.count += advanced; + ret + } + } + } + #[rustc_inherit_overflow_checks] #[doc(hidden)] unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> ::Item @@ -191,6 +206,13 @@ fn enumerate( let count = self.count + self.iter.len(); self.iter.rfold(init, enumerate(count, fold)) } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + // we do not need to update the count since that only tallies the number of items + // consumed from the front. consuming items from the back can never reduce that. + self.iter.advance_back_by(n) + } } #[stable(feature = "rust1", since = "1.0.0")] diff --git a/library/core/src/iter/adapters/flatten.rs b/library/core/src/iter/adapters/flatten.rs index 48880a4d91a..e1d665bb2a1 100644 --- a/library/core/src/iter/adapters/flatten.rs +++ b/library/core/src/iter/adapters/flatten.rs @@ -391,6 +391,40 @@ fn flatten( init } + + #[inline] + #[rustc_inherit_overflow_checks] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + let mut rem = n; + loop { + if let Some(ref mut front) = self.frontiter { + match front.advance_by(rem) { + ret @ Ok(_) => return ret, + Err(advanced) => rem -= advanced, + } + } + self.frontiter = match self.iter.next() { + Some(iterable) => Some(iterable.into_iter()), + _ => break, + } + } + + self.frontiter = None; + + if let Some(ref mut back) = self.backiter { + if let Err(advanced) = back.advance_by(rem) { + rem -= advanced + } + } + + if rem > 0 { + return Err(n - rem); + } + + self.backiter = None; + + Ok(()) + } } impl DoubleEndedIterator for FlattenCompat @@ -486,6 +520,41 @@ fn flatten( init } + + #[inline] + #[rustc_inherit_overflow_checks] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let mut rem = n; + loop { + if let Some(ref mut back) = self.backiter { + match back.advance_back_by(rem) { + ret @ Ok(_) => return ret, + Err(advanced) => rem -= advanced, + } + } + match self.iter.next_back() { + Some(iterable) => self.backiter = Some(iterable.into_iter()), + _ => break, + } + } + + self.backiter = None; + + if let Some(ref mut front) = self.frontiter { + match front.advance_back_by(rem) { + ret @ Ok(_) => return ret, + Err(advanced) => rem -= advanced, + } + } + + if rem > 0 { + return Err(n - rem); + } + + self.frontiter = None; + + Ok(()) + } } trait ConstSizeIntoIterator: IntoIterator { diff --git a/library/core/src/iter/adapters/skip.rs b/library/core/src/iter/adapters/skip.rs index c358a6d12b7..e29ff1291cf 100644 --- a/library/core/src/iter/adapters/skip.rs +++ b/library/core/src/iter/adapters/skip.rs @@ -114,6 +114,17 @@ fn fold(mut self, init: Acc, fold: Fold) -> Acc } self.iter.fold(init, fold) } + + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + if self.n >= n { + self.n -= n; + return Ok(()); + } + let rem = n - self.n; + self.n = 0; + self.iter.advance_by(rem) + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -174,6 +185,16 @@ fn ok(mut f: impl FnMut(Acc, T) -> Acc) -> impl FnMut(Acc, T) -> Result< self.try_rfold(init, ok(fold)).unwrap() } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let min = crate::cmp::min(self.len(), n); + return match self.iter.advance_back_by(min) { + ret @ Ok(_) if n <= min => ret, + Ok(_) => Err(min), + _ => panic!("ExactSizeIterator contract violation"), + }; + } } #[stable(feature = "fused", since = "1.26.0")] diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs index beda8c32c6b..bff68339dda 100644 --- a/library/core/src/iter/adapters/take.rs +++ b/library/core/src/iter/adapters/take.rs @@ -111,6 +111,22 @@ fn ok(mut f: impl FnMut(B, T) -> B) -> impl FnMut(B, T) -> Result { self.try_fold(init, ok(fold)).unwrap() } + + #[inline] + #[rustc_inherit_overflow_checks] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + let min = crate::cmp::min(self.n, n); + return match self.iter.advance_by(min) { + Ok(_) => { + self.n -= min; + if min < n { Err(min) } else { Ok(()) } + } + ret @ Err(advanced) => { + self.n -= advanced; + ret + } + }; + } } #[unstable(issue = "none", feature = "inplace_iteration")] @@ -197,6 +213,24 @@ fn rfold(mut self, init: Acc, fold: Fold) -> Acc } } } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let inner_len = self.iter.len(); + let len = self.n; + let remainder = len.saturating_sub(n); + let to_advance = inner_len - remainder; + match self.iter.advance_back_by(to_advance) { + Ok(_) => { + self.n = remainder; + if n > len { + return Err(len); + } + return Ok(()); + } + _ => panic!("ExactSizeIterator contract violation"), + } + } } #[stable(feature = "rust1", since = "1.0.0")] diff --git a/library/core/src/iter/range.rs b/library/core/src/iter/range.rs index 0f835689699..06733a1b50b 100644 --- a/library/core/src/iter/range.rs +++ b/library/core/src/iter/range.rs @@ -521,10 +521,12 @@ trait RangeIteratorImpl { // Iterator fn spec_next(&mut self) -> Option; fn spec_nth(&mut self, n: usize) -> Option; + fn spec_advance_by(&mut self, n: usize) -> Result<(), usize>; // DoubleEndedIterator fn spec_next_back(&mut self) -> Option; fn spec_nth_back(&mut self, n: usize) -> Option; + fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize>; } impl RangeIteratorImpl for ops::Range { @@ -555,6 +557,22 @@ impl RangeIteratorImpl for ops::Range { None } + #[inline] + default fn spec_advance_by(&mut self, n: usize) -> Result<(), usize> { + let available = if self.start <= self.end { + Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX) + } else { + 0 + }; + + let taken = available.min(n); + + self.start = + Step::forward_checked(self.start.clone(), taken).expect("`Step` invariants not upheld"); + + if taken < n { Err(taken) } else { Ok(()) } + } + #[inline] default fn spec_next_back(&mut self) -> Option { if self.start < self.end { @@ -579,6 +597,22 @@ impl RangeIteratorImpl for ops::Range { self.end = self.start.clone(); None } + + #[inline] + default fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let available = if self.start <= self.end { + Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX) + } else { + 0 + }; + + let taken = available.min(n); + + self.end = + Step::backward_checked(self.end.clone(), taken).expect("`Step` invariants not upheld"); + + if taken < n { Err(taken) } else { Ok(()) } + } } impl RangeIteratorImpl for ops::Range { @@ -607,6 +641,25 @@ fn spec_nth(&mut self, n: usize) -> Option { None } + #[inline] + fn spec_advance_by(&mut self, n: usize) -> Result<(), usize> { + let available = if self.start <= self.end { + Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX) + } else { + 0 + }; + + let taken = available.min(n); + + // SAFETY: the conditions above ensure that the count is in bounds. If start <= end + // then steps_between either returns a bound to which we clamp or returns None which + // together with the initial inequality implies more than usize::MAX steps. + // Otherwise 0 is returned which always safe to use. + self.start = unsafe { Step::forward_unchecked(self.start.clone(), taken) }; + + if taken < n { Err(taken) } else { Ok(()) } + } + #[inline] fn spec_next_back(&mut self) -> Option { if self.start < self.end { @@ -631,6 +684,22 @@ fn spec_nth_back(&mut self, n: usize) -> Option { self.end = self.start.clone(); None } + + #[inline] + fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize> { + let available = if self.start <= self.end { + Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX) + } else { + 0 + }; + + let taken = available.min(n); + + // SAFETY: same as the spec_advance_by() implementation + self.end = unsafe { Step::backward_unchecked(self.end.clone(), taken) }; + + if taken < n { Err(taken) } else { Ok(()) } + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -677,6 +746,11 @@ fn is_sorted(self) -> bool { true } + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), usize> { + self.spec_advance_by(n) + } + #[inline] #[doc(hidden)] unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item @@ -750,6 +824,11 @@ fn next_back(&mut self) -> Option { fn nth_back(&mut self, n: usize) -> Option { self.spec_nth_back(n) } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + self.spec_advance_back_by(n) + } } // Safety: diff --git a/library/core/src/iter/traits/double_ended.rs b/library/core/src/iter/traits/double_ended.rs index 9a9cf200770..9a589c1f3b5 100644 --- a/library/core/src/iter/traits/double_ended.rs +++ b/library/core/src/iter/traits/double_ended.rs @@ -103,9 +103,15 @@ pub trait DoubleEndedIterator: Iterator { /// elements the iterator is advanced by before running out of elements (i.e. the length /// of the iterator). Note that `k` is always less than `n`. /// - /// Calling `advance_back_by(0)` does not consume any elements and always returns [`Ok(())`]. + /// Calling `advance_back_by(0)` can do meaningful work, for example [`Flatten`] can advance its + /// outer iterator until it finds an inner iterator that is not empty, which then often + /// allows it to return a more accurate `size_hint()` than in its initial state. + /// `advance_back_by(0)` may either return `Ok()` or `Err(0)`. The former conveys no information + /// whether the iterator is or is not exhausted, the latter can be treated as if [`next_back`] + /// had returned `None`. Replacing a `Err(0)` with `Ok` is only correct for `n = 0`. /// /// [`advance_by`]: Iterator::advance_by + /// [`Flatten`]: crate::iter::Flatten /// [`next_back`]: DoubleEndedIterator::next_back /// /// # Examples diff --git a/library/core/src/iter/traits/iterator.rs b/library/core/src/iter/traits/iterator.rs index f2336fb2865..e6b6aec7d94 100644 --- a/library/core/src/iter/traits/iterator.rs +++ b/library/core/src/iter/traits/iterator.rs @@ -246,8 +246,14 @@ fn some(_: Option, x: T) -> Option { /// of elements the iterator is advanced by before running out of elements (i.e. the /// length of the iterator). Note that `k` is always less than `n`. /// - /// Calling `advance_by(0)` does not consume any elements and always returns [`Ok(())`][Ok]. + /// Calling `advance_by(0)` can do meaningful work, for example [`Flatten`] + /// can advance its outer iterator until it finds an inner iterator that is not empty, which + /// then often allows it to return a more accurate `size_hint()` than in its initial state. + /// `advance_by(0)` may either return `Ok()` or `Err(0)`. The former conveys no information + /// whether the iterator is or is not exhausted, the latter can be treated as if [`next`] + /// had returned `None`. Replacing a `Err(0)` with `Ok` is only correct for `n = 0`. /// + /// [`Flatten`]: crate::iter::Flatten /// [`next`]: Iterator::next /// /// # Examples diff --git a/library/core/tests/iter/adapters/flatten.rs b/library/core/tests/iter/adapters/flatten.rs index aaac39c2979..4ae50a2f066 100644 --- a/library/core/tests/iter/adapters/flatten.rs +++ b/library/core/tests/iter/adapters/flatten.rs @@ -58,6 +58,23 @@ fn test_flatten_try_folds() { assert_eq!(iter.next_back(), Some(35)); } +#[test] +fn test_flatten_advance_by() { + let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten(); + it.advance_by(5).unwrap(); + assert_eq!(it.next(), Some(5)); + it.advance_by(9).unwrap(); + assert_eq!(it.next(), Some(15)); + it.advance_back_by(4).unwrap(); + assert_eq!(it.next_back(), Some(35)); + it.advance_back_by(9).unwrap(); + assert_eq!(it.next_back(), Some(25)); + + assert_eq!(it.advance_by(usize::MAX), Err(9)); + assert_eq!(it.advance_back_by(usize::MAX), Err(0)); + assert_eq!(it.size_hint(), (0, Some(0))); +} + #[test] fn test_flatten_non_fused_outer() { let mut iter = NonFused::new(once(0..2)).flatten(); diff --git a/library/core/tests/iter/range.rs b/library/core/tests/iter/range.rs index 44adc3c58d2..6b4cf33efe1 100644 --- a/library/core/tests/iter/range.rs +++ b/library/core/tests/iter/range.rs @@ -285,6 +285,29 @@ fn test_range_step() { assert_eq!((isize::MIN..isize::MAX).step_by(1).size_hint(), (usize::MAX, Some(usize::MAX))); } +#[test] +fn test_range_advance_by() { + let mut r = 0..usize::MAX; + r.advance_by(0).unwrap(); + r.advance_back_by(0).unwrap(); + + assert_eq!(r.len(), usize::MAX); + + r.advance_by(1).unwrap(); + r.advance_back_by(1).unwrap(); + + assert_eq!((r.start, r.end), (1, usize::MAX - 1)); + + assert_eq!(r.advance_by(usize::MAX), Err(usize::MAX - 2)); + + let mut r = 0u128..u128::MAX; + + r.advance_by(usize::MAX).unwrap(); + r.advance_back_by(usize::MAX).unwrap(); + + assert_eq!((r.start, r.end), (0u128 + usize::MAX as u128, u128::MAX - usize::MAX as u128)); +} + #[test] fn test_range_inclusive_step() { assert_eq!((0..=50).step_by(10).collect::>(), [0, 10, 20, 30, 40, 50]);