diff --git a/library/core/src/iter/adapters/step_by.rs b/library/core/src/iter/adapters/step_by.rs index b8b96417d13..616dd0afc51 100644 --- a/library/core/src/iter/adapters/step_by.rs +++ b/library/core/src/iter/adapters/step_by.rs @@ -1,6 +1,7 @@ use crate::{ intrinsics, iter::{from_fn, TrustedLen, TrustedRandomAccess}, + num::NonZeroUsize, ops::{Range, Try}, }; @@ -22,7 +23,11 @@ pub struct StepBy { /// Additionally this type-dependent preprocessing means specialized implementations /// cannot be used interchangeably. iter: I, - step: usize, + /// This field is `step - 1`, aka the correct amount to pass to `nth` when iterating. + /// It MUST NOT be `usize::MAX`, as `unsafe` code depends on being able to add one + /// without the risk of overflow. (This is important so that length calculations + /// don't need to check for division-by-zero, for example.) + step_minus_one: usize, first_take: bool, } @@ -31,7 +36,16 @@ impl StepBy { pub(in crate::iter) fn new(iter: I, step: usize) -> StepBy { assert!(step != 0); let iter = >::setup(iter, step); - StepBy { iter, step: step - 1, first_take: true } + StepBy { iter, step_minus_one: step - 1, first_take: true } + } + + /// The `step` that was originally passed to `Iterator::step_by(step)`, + /// aka `self.step_minus_one + 1`. + #[inline] + fn original_step(&self) -> NonZeroUsize { + // SAFETY: By type invariant, `step_minus_one` cannot be `MAX`, which + // means the addition cannot overflow and the result cannot be zero. + unsafe { NonZeroUsize::new_unchecked(intrinsics::unchecked_add(self.step_minus_one, 1)) } } } @@ -81,8 +95,8 @@ impl StepBy // The zero-based index starting from the end of the iterator of the // last element. Used in the `DoubleEndedIterator` implementation. fn next_back_index(&self) -> usize { - let rem = self.iter.len() % (self.step + 1); - if self.first_take { if rem == 0 { self.step } else { rem - 1 } } else { rem } + let rem = self.iter.len() % self.original_step(); + if self.first_take { if rem == 0 { self.step_minus_one } else { rem - 1 } } else { rem } } } @@ -209,7 +223,7 @@ unsafe impl StepByImpl for StepBy { #[inline] default fn spec_next(&mut self) -> Option { - let step_size = if self.first_take { 0 } else { self.step }; + let step_size = if self.first_take { 0 } else { self.step_minus_one }; self.first_take = false; self.iter.nth(step_size) } @@ -217,22 +231,22 @@ unsafe impl StepByImpl for StepBy { #[inline] default fn spec_size_hint(&self) -> (usize, Option) { #[inline] - fn first_size(step: usize) -> impl Fn(usize) -> usize { - move |n| if n == 0 { 0 } else { 1 + (n - 1) / (step + 1) } + fn first_size(step: NonZeroUsize) -> impl Fn(usize) -> usize { + move |n| if n == 0 { 0 } else { 1 + (n - 1) / step } } #[inline] - fn other_size(step: usize) -> impl Fn(usize) -> usize { - move |n| n / (step + 1) + fn other_size(step: NonZeroUsize) -> impl Fn(usize) -> usize { + move |n| n / step } let (low, high) = self.iter.size_hint(); if self.first_take { - let f = first_size(self.step); + let f = first_size(self.original_step()); (f(low), high.map(f)) } else { - let f = other_size(self.step); + let f = other_size(self.original_step()); (f(low), high.map(f)) } } @@ -247,10 +261,9 @@ fn other_size(step: usize) -> impl Fn(usize) -> usize { } n -= 1; } - // n and self.step are indices, we need to add 1 to get the amount of elements + // n and self.step_minus_one are indices, we need to add 1 to get the amount of elements // When calling `.nth`, we need to subtract 1 again to convert back to an index - // step + 1 can't overflow because `.step_by` sets `self.step` to `step - 1` - let mut step = self.step + 1; + let mut step = self.original_step().get(); // n + 1 could overflow // thus, if n is usize::MAX, instead of adding one, we call .nth(step) if n == usize::MAX { @@ -288,8 +301,11 @@ fn other_size(step: usize) -> impl Fn(usize) -> usize { R: Try, { #[inline] - fn nth(iter: &mut I, step: usize) -> impl FnMut() -> Option + '_ { - move || iter.nth(step) + fn nth( + iter: &mut I, + step_minus_one: usize, + ) -> impl FnMut() -> Option + '_ { + move || iter.nth(step_minus_one) } if self.first_take { @@ -299,7 +315,7 @@ fn nth(iter: &mut I, step: usize) -> impl FnMut() -> Option acc = f(acc, x)?, } } - from_fn(nth(&mut self.iter, self.step)).try_fold(acc, f) + from_fn(nth(&mut self.iter, self.step_minus_one)).try_fold(acc, f) } default fn spec_fold(mut self, mut acc: Acc, mut f: F) -> Acc @@ -307,8 +323,11 @@ fn nth(iter: &mut I, step: usize) -> impl FnMut() -> Option Acc, { #[inline] - fn nth(iter: &mut I, step: usize) -> impl FnMut() -> Option + '_ { - move || iter.nth(step) + fn nth( + iter: &mut I, + step_minus_one: usize, + ) -> impl FnMut() -> Option + '_ { + move || iter.nth(step_minus_one) } if self.first_take { @@ -318,7 +337,7 @@ fn nth(iter: &mut I, step: usize) -> impl FnMut() -> Option acc = f(acc, x), } } - from_fn(nth(&mut self.iter, self.step)).fold(acc, f) + from_fn(nth(&mut self.iter, self.step_minus_one)).fold(acc, f) } } @@ -336,7 +355,7 @@ unsafe impl StepByBackImpl for St // is out of bounds because the length of `self.iter` does not exceed // `usize::MAX` (because `I: ExactSizeIterator`) and `nth_back` is // zero-indexed - let n = n.saturating_mul(self.step + 1).saturating_add(self.next_back_index()); + let n = n.saturating_mul(self.original_step().get()).saturating_add(self.next_back_index()); self.iter.nth_back(n) } @@ -348,16 +367,16 @@ unsafe impl StepByBackImpl for St #[inline] fn nth_back( iter: &mut I, - step: usize, + step_minus_one: usize, ) -> impl FnMut() -> Option + '_ { - move || iter.nth_back(step) + move || iter.nth_back(step_minus_one) } match self.next_back() { None => try { init }, Some(x) => { let acc = f(init, x)?; - from_fn(nth_back(&mut self.iter, self.step)).try_fold(acc, f) + from_fn(nth_back(&mut self.iter, self.step_minus_one)).try_fold(acc, f) } } } @@ -371,16 +390,16 @@ fn nth_back( #[inline] fn nth_back( iter: &mut I, - step: usize, + step_minus_one: usize, ) -> impl FnMut() -> Option + '_ { - move || iter.nth_back(step) + move || iter.nth_back(step_minus_one) } match self.next_back() { None => init, Some(x) => { let acc = f(init, x); - from_fn(nth_back(&mut self.iter, self.step)).fold(acc, f) + from_fn(nth_back(&mut self.iter, self.step_minus_one)).fold(acc, f) } } } @@ -424,8 +443,7 @@ unsafe impl StepByImpl> for StepBy> { fn spec_next(&mut self) -> Option<$t> { // if a step size larger than the type has been specified fall back to // t::MAX, in which case remaining will be at most 1. - // The `+ 1` can't overflow since the constructor substracted 1 from the original value. - let step = <$t>::try_from(self.step + 1).unwrap_or(<$t>::MAX); + let step = <$t>::try_from(self.original_step().get()).unwrap_or(<$t>::MAX); let remaining = self.iter.end; if remaining > 0 { let val = self.iter.start; @@ -474,7 +492,7 @@ fn spec_fold(self, init: Acc, mut f: F) -> Acc { // if a step size larger than the type has been specified fall back to // t::MAX, in which case remaining will be at most 1. - let step = <$t>::try_from(self.step + 1).unwrap_or(<$t>::MAX); + let step = <$t>::try_from(self.original_step().get()).unwrap_or(<$t>::MAX); let remaining = self.iter.end; let mut acc = init; let mut val = self.iter.start; @@ -500,7 +518,7 @@ unsafe impl StepByBackImpl> for StepBy> { fn spec_next_back(&mut self) -> Option where Range<$t>: DoubleEndedIterator + ExactSizeIterator, { - let step = (self.step + 1) as $t; + let step = self.original_step().get() as $t; let remaining = self.iter.end; if remaining > 0 { let start = self.iter.start; diff --git a/tests/codegen/step_by-overflow-checks.rs b/tests/codegen/step_by-overflow-checks.rs new file mode 100644 index 00000000000..43e8514a8b7 --- /dev/null +++ b/tests/codegen/step_by-overflow-checks.rs @@ -0,0 +1,26 @@ +//@ compile-flags: -O + +#![crate_type = "lib"] + +use std::iter::StepBy; +use std::slice::Iter; + +// The constructor for `StepBy` ensures we can never end up needing to do zero +// checks on denominators, so check that the code isn't emitting panic paths. + +// CHECK-LABEL: @step_by_len_std +#[no_mangle] +pub fn step_by_len_std(x: &StepBy>) -> usize { + // CHECK-NOT: div_by_zero + // CHECK: udiv + // CHECK-NOT: div_by_zero + x.len() +} + +// CHECK-LABEL: @step_by_len_naive +#[no_mangle] +pub fn step_by_len_naive(x: Iter, step_minus_one: usize) -> usize { + // CHECK: udiv + // CHECK: call{{.+}}div_by_zero + x.len() / (step_minus_one + 1) +}