implement advance_(back_)_by on more iterators

This commit is contained in:
The8472 2021-07-12 21:40:38 +02:00
parent 6dc08b909b
commit 2c6e67105e
15 changed files with 375 additions and 2 deletions

View File

@ -111,6 +111,7 @@
// that the feature-gate isn't enabled. Ideally, it wouldn't check for the feature gate for docs
// from other crates, but since this can only appear for lang items, it doesn't seem worth fixing.
#![feature(intra_doc_pointers)]
#![feature(iter_advance_by)]
#![feature(iter_zip)]
#![feature(lang_items)]
#![feature(layout_for_ptr)]

View File

@ -161,6 +161,28 @@ fn size_hint(&self) -> (usize, Option<usize>) {
(exact, Some(exact))
}
#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let step_size = self.len().min(n);
if mem::size_of::<T>() == 0 {
// SAFETY: due to unchecked casts of unsigned amounts to signed offsets the wraparound
// effectively results in unsigned pointers representing positions 0..usize::MAX,
// which is valid for ZSTs.
self.ptr = unsafe { arith_offset(self.ptr as *const i8, step_size as isize) as *mut T }
} else {
let to_drop = ptr::slice_from_raw_parts_mut(self.ptr as *mut T, step_size);
// SAFETY: the min() above ensures that step_size is in bounds
unsafe {
self.ptr = self.ptr.add(step_size);
ptr::drop_in_place(to_drop);
}
}
if step_size < n {
return Err(step_size);
}
Ok(())
}
#[inline]
fn count(self) -> usize {
self.len()
@ -203,6 +225,29 @@ fn next_back(&mut self) -> Option<T> {
Some(unsafe { ptr::read(self.end) })
}
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let step_size = self.len().min(n);
if mem::size_of::<T>() == 0 {
// SAFETY: same as for advance_by()
self.end = unsafe {
arith_offset(self.end as *const i8, step_size.wrapping_neg() as isize) as *mut T
}
} else {
// SAFETY: same as for advance_by()
self.end = unsafe { self.end.offset(step_size.wrapping_neg() as isize) };
let to_drop = ptr::slice_from_raw_parts_mut(self.end as *mut T, step_size);
// SAFETY: same as for advance_by()
unsafe {
ptr::drop_in_place(to_drop);
}
}
if step_size < n {
return Err(step_size);
}
Ok(())
}
}
#[stable(feature = "rust1", since = "1.0.0")]

View File

@ -18,6 +18,7 @@
#![feature(binary_heap_retain)]
#![feature(binary_heap_as_slice)]
#![feature(inplace_iteration)]
#![feature(iter_advance_by)]
#![feature(slice_group_by)]
#![feature(slice_partition_dedup)]
#![feature(vec_spare_capacity)]

View File

@ -970,6 +970,24 @@ fn drop(&mut self) {
assert_eq!(unsafe { DROPS }, 3);
}
#[test]
fn test_into_iter_advance_by() {
let mut i = vec![1, 2, 3, 4, 5].into_iter();
i.advance_by(0).unwrap();
i.advance_back_by(0).unwrap();
assert_eq!(i.as_slice(), [1, 2, 3, 4, 5]);
i.advance_by(1).unwrap();
i.advance_back_by(1).unwrap();
assert_eq!(i.as_slice(), [2, 3, 4]);
assert_eq!(i.advance_back_by(usize::MAX), Err(3));
assert_eq!(i.advance_by(usize::MAX), Err(0));
assert_eq!(i.len(), 0);
}
#[test]
fn test_from_iter_specialization() {
let src: Vec<usize> = vec![0usize; 1];

View File

@ -76,6 +76,11 @@ fn count(self) -> usize {
self.it.count()
}
#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.it.advance_by(n)
}
#[doc(hidden)]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> T
where
@ -112,6 +117,11 @@ fn rfold<Acc, F>(self, init: Acc, f: F) -> Acc
{
self.it.rfold(init, copy_fold(f))
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
self.it.advance_back_by(n)
}
}
#[stable(feature = "iter_copied", since = "1.36.0")]

View File

@ -79,6 +79,27 @@ fn try_fold<Acc, F, R>(&mut self, mut acc: Acc, mut f: F) -> R
}
}
#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
match self.iter.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
while rem > 0 {
self.iter = self.orig.clone();
match self.iter.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(0) => return Err(n - rem),
Err(advanced) => rem -= advanced,
}
}
Ok(())
}
// No `fold` override, because `fold` doesn't make much sense for `Cycle`,
// and we can't do anything better than the default.
}

View File

@ -112,6 +112,21 @@ fn enumerate<T, Acc>(
self.iter.fold(init, enumerate(self.count, fold))
}
#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
match self.iter.advance_by(n) {
ret @ Ok(_) => {
self.count += n;
ret
}
ret @ Err(advanced) => {
self.count += advanced;
ret
}
}
}
#[rustc_inherit_overflow_checks]
#[doc(hidden)]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
@ -191,6 +206,13 @@ fn enumerate<T, Acc>(
let count = self.count + self.iter.len();
self.iter.rfold(init, enumerate(count, fold))
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
// we do not need to update the count since that only tallies the number of items
// consumed from the front. consuming items from the back can never reduce that.
self.iter.advance_back_by(n)
}
}
#[stable(feature = "rust1", since = "1.0.0")]

View File

@ -391,6 +391,40 @@ fn flatten<T: IntoIterator, Acc>(
init
}
#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
loop {
if let Some(ref mut front) = self.frontiter {
match front.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}
self.frontiter = match self.iter.next() {
Some(iterable) => Some(iterable.into_iter()),
_ => break,
}
}
self.frontiter = None;
if let Some(ref mut back) = self.backiter {
if let Err(advanced) = back.advance_by(rem) {
rem -= advanced
}
}
if rem > 0 {
return Err(n - rem);
}
self.backiter = None;
Ok(())
}
}
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
@ -486,6 +520,41 @@ fn flatten<T: IntoIterator, Acc>(
init
}
#[inline]
#[rustc_inherit_overflow_checks]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
loop {
if let Some(ref mut back) = self.backiter {
match back.advance_back_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}
match self.iter.next_back() {
Some(iterable) => self.backiter = Some(iterable.into_iter()),
_ => break,
}
}
self.backiter = None;
if let Some(ref mut front) = self.frontiter {
match front.advance_back_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}
if rem > 0 {
return Err(n - rem);
}
self.frontiter = None;
Ok(())
}
}
trait ConstSizeIntoIterator: IntoIterator {

View File

@ -114,6 +114,17 @@ fn fold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
}
self.iter.fold(init, fold)
}
#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
if self.n >= n {
self.n -= n;
return Ok(());
}
let rem = n - self.n;
self.n = 0;
self.iter.advance_by(rem)
}
}
#[stable(feature = "rust1", since = "1.0.0")]
@ -174,6 +185,16 @@ fn ok<Acc, T>(mut f: impl FnMut(Acc, T) -> Acc) -> impl FnMut(Acc, T) -> Result<
self.try_rfold(init, ok(fold)).unwrap()
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let min = crate::cmp::min(self.len(), n);
return match self.iter.advance_back_by(min) {
ret @ Ok(_) if n <= min => ret,
Ok(_) => Err(min),
_ => panic!("ExactSizeIterator contract violation"),
};
}
}
#[stable(feature = "fused", since = "1.26.0")]

View File

@ -111,6 +111,22 @@ fn ok<B, T>(mut f: impl FnMut(B, T) -> B) -> impl FnMut(B, T) -> Result<B, !> {
self.try_fold(init, ok(fold)).unwrap()
}
#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let min = crate::cmp::min(self.n, n);
return match self.iter.advance_by(min) {
Ok(_) => {
self.n -= min;
if min < n { Err(min) } else { Ok(()) }
}
ret @ Err(advanced) => {
self.n -= advanced;
ret
}
};
}
}
#[unstable(issue = "none", feature = "inplace_iteration")]
@ -197,6 +213,24 @@ fn rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
}
}
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let inner_len = self.iter.len();
let len = self.n;
let remainder = len.saturating_sub(n);
let to_advance = inner_len - remainder;
match self.iter.advance_back_by(to_advance) {
Ok(_) => {
self.n = remainder;
if n > len {
return Err(len);
}
return Ok(());
}
_ => panic!("ExactSizeIterator contract violation"),
}
}
}
#[stable(feature = "rust1", since = "1.0.0")]

View File

@ -521,10 +521,12 @@ trait RangeIteratorImpl {
// Iterator
fn spec_next(&mut self) -> Option<Self::Item>;
fn spec_nth(&mut self, n: usize) -> Option<Self::Item>;
fn spec_advance_by(&mut self, n: usize) -> Result<(), usize>;
// DoubleEndedIterator
fn spec_next_back(&mut self) -> Option<Self::Item>;
fn spec_nth_back(&mut self, n: usize) -> Option<Self::Item>;
fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize>;
}
impl<A: Step> RangeIteratorImpl for ops::Range<A> {
@ -555,6 +557,22 @@ impl<A: Step> RangeIteratorImpl for ops::Range<A> {
None
}
#[inline]
default fn spec_advance_by(&mut self, n: usize) -> Result<(), usize> {
let available = if self.start <= self.end {
Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX)
} else {
0
};
let taken = available.min(n);
self.start =
Step::forward_checked(self.start.clone(), taken).expect("`Step` invariants not upheld");
if taken < n { Err(taken) } else { Ok(()) }
}
#[inline]
default fn spec_next_back(&mut self) -> Option<A> {
if self.start < self.end {
@ -579,6 +597,22 @@ impl<A: Step> RangeIteratorImpl for ops::Range<A> {
self.end = self.start.clone();
None
}
#[inline]
default fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let available = if self.start <= self.end {
Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX)
} else {
0
};
let taken = available.min(n);
self.end =
Step::backward_checked(self.end.clone(), taken).expect("`Step` invariants not upheld");
if taken < n { Err(taken) } else { Ok(()) }
}
}
impl<T: TrustedStep> RangeIteratorImpl for ops::Range<T> {
@ -607,6 +641,25 @@ fn spec_nth(&mut self, n: usize) -> Option<T> {
None
}
#[inline]
fn spec_advance_by(&mut self, n: usize) -> Result<(), usize> {
let available = if self.start <= self.end {
Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX)
} else {
0
};
let taken = available.min(n);
// SAFETY: the conditions above ensure that the count is in bounds. If start <= end
// then steps_between either returns a bound to which we clamp or returns None which
// together with the initial inequality implies more than usize::MAX steps.
// Otherwise 0 is returned which always safe to use.
self.start = unsafe { Step::forward_unchecked(self.start.clone(), taken) };
if taken < n { Err(taken) } else { Ok(()) }
}
#[inline]
fn spec_next_back(&mut self) -> Option<T> {
if self.start < self.end {
@ -631,6 +684,22 @@ fn spec_nth_back(&mut self, n: usize) -> Option<T> {
self.end = self.start.clone();
None
}
#[inline]
fn spec_advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let available = if self.start <= self.end {
Step::steps_between(&self.start, &self.end).unwrap_or(usize::MAX)
} else {
0
};
let taken = available.min(n);
// SAFETY: same as the spec_advance_by() implementation
self.end = unsafe { Step::backward_unchecked(self.end.clone(), taken) };
if taken < n { Err(taken) } else { Ok(()) }
}
}
#[stable(feature = "rust1", since = "1.0.0")]
@ -677,6 +746,11 @@ fn is_sorted(self) -> bool {
true
}
#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.spec_advance_by(n)
}
#[inline]
#[doc(hidden)]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item
@ -750,6 +824,11 @@ fn next_back(&mut self) -> Option<A> {
fn nth_back(&mut self, n: usize) -> Option<A> {
self.spec_nth_back(n)
}
#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
self.spec_advance_back_by(n)
}
}
// Safety:

View File

@ -103,9 +103,15 @@ pub trait DoubleEndedIterator: Iterator {
/// elements the iterator is advanced by before running out of elements (i.e. the length
/// of the iterator). Note that `k` is always less than `n`.
///
/// Calling `advance_back_by(0)` does not consume any elements and always returns [`Ok(())`].
/// Calling `advance_back_by(0)` can do meaningful work, for example [`Flatten`] can advance its
/// outer iterator until it finds an inner iterator that is not empty, which then often
/// allows it to return a more accurate `size_hint()` than in its initial state.
/// `advance_back_by(0)` may either return `Ok()` or `Err(0)`. The former conveys no information
/// whether the iterator is or is not exhausted, the latter can be treated as if [`next_back`]
/// had returned `None`. Replacing a `Err(0)` with `Ok` is only correct for `n = 0`.
///
/// [`advance_by`]: Iterator::advance_by
/// [`Flatten`]: crate::iter::Flatten
/// [`next_back`]: DoubleEndedIterator::next_back
///
/// # Examples

View File

@ -246,8 +246,14 @@ fn some<T>(_: Option<T>, x: T) -> Option<T> {
/// of elements the iterator is advanced by before running out of elements (i.e. the
/// length of the iterator). Note that `k` is always less than `n`.
///
/// Calling `advance_by(0)` does not consume any elements and always returns [`Ok(())`][Ok].
/// Calling `advance_by(0)` can do meaningful work, for example [`Flatten`]
/// can advance its outer iterator until it finds an inner iterator that is not empty, which
/// then often allows it to return a more accurate `size_hint()` than in its initial state.
/// `advance_by(0)` may either return `Ok()` or `Err(0)`. The former conveys no information
/// whether the iterator is or is not exhausted, the latter can be treated as if [`next`]
/// had returned `None`. Replacing a `Err(0)` with `Ok` is only correct for `n = 0`.
///
/// [`Flatten`]: crate::iter::Flatten
/// [`next`]: Iterator::next
///
/// # Examples

View File

@ -58,6 +58,23 @@ fn test_flatten_try_folds() {
assert_eq!(iter.next_back(), Some(35));
}
#[test]
fn test_flatten_advance_by() {
let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
it.advance_by(5).unwrap();
assert_eq!(it.next(), Some(5));
it.advance_by(9).unwrap();
assert_eq!(it.next(), Some(15));
it.advance_back_by(4).unwrap();
assert_eq!(it.next_back(), Some(35));
it.advance_back_by(9).unwrap();
assert_eq!(it.next_back(), Some(25));
assert_eq!(it.advance_by(usize::MAX), Err(9));
assert_eq!(it.advance_back_by(usize::MAX), Err(0));
assert_eq!(it.size_hint(), (0, Some(0)));
}
#[test]
fn test_flatten_non_fused_outer() {
let mut iter = NonFused::new(once(0..2)).flatten();

View File

@ -285,6 +285,29 @@ fn test_range_step() {
assert_eq!((isize::MIN..isize::MAX).step_by(1).size_hint(), (usize::MAX, Some(usize::MAX)));
}
#[test]
fn test_range_advance_by() {
let mut r = 0..usize::MAX;
r.advance_by(0).unwrap();
r.advance_back_by(0).unwrap();
assert_eq!(r.len(), usize::MAX);
r.advance_by(1).unwrap();
r.advance_back_by(1).unwrap();
assert_eq!((r.start, r.end), (1, usize::MAX - 1));
assert_eq!(r.advance_by(usize::MAX), Err(usize::MAX - 2));
let mut r = 0u128..u128::MAX;
r.advance_by(usize::MAX).unwrap();
r.advance_back_by(usize::MAX).unwrap();
assert_eq!((r.start, r.end), (0u128 + usize::MAX as u128, u128::MAX - usize::MAX as u128));
}
#[test]
fn test_range_inclusive_step() {
assert_eq!((0..=50).step_by(10).collect::<Vec<_>>(), [0, 10, 20, 30, 40, 50]);