Rollup merge of #128166 - ChaiTRex:isqrt, r=tgross35
Improved `checked_isqrt` and `isqrt` methods ### Improved tests of `isqrt` and `checked_isqrt` implementations * Inputs chosen more thoroughly and systematically. * Checks that `isqrt` and `checked_isqrt` have equivalent results for signed types, either equivalent numerically or equivalent as a panic and a `None`. * Checks that `isqrt` has numerically-equivalent results for unsigned types and their `NonZero` counterparts. ### Added benchmarks for `isqrt` implementations ### Greatly sped up `checked_isqrt` and `isqrt` methods * Uses a lookup table for 8-bit integers and then the Karatsuba square root algorithm for larger integers. * Includes optimization hints that give the compiler the exact numeric range of results. ### Feature tracking issue `isqrt` is an unstable feature tracked at #116226. <details><summary>Benchmarked improvements</summary> ### Command used to benchmark ./x bench library/core -- int_sqrt ### Before benchmarks: num::int_sqrt::i128::isqrt 439591.65/iter +/- 6652.70 num::int_sqrt::i16::isqrt 5302.97/iter +/- 160.93 num::int_sqrt::i32::isqrt 62999.11/iter +/- 2022.05 num::int_sqrt::i64::isqrt 125248.81/iter +/- 1674.43 num::int_sqrt::i8::isqrt 123.56/iter +/- 1.87 num::int_sqrt::isize::isqrt 125356.56/iter +/- 1017.03 num::int_sqrt::non_zero_u128::isqrt 437443.75/iter +/- 3535.43 num::int_sqrt::non_zero_u16::isqrt 8604.58/iter +/- 94.76 num::int_sqrt::non_zero_u32::isqrt 62933.33/iter +/- 517.30 num::int_sqrt::non_zero_u64::isqrt 125076.38/iter +/- 11340.61 num::int_sqrt::non_zero_u8::isqrt 221.51/iter +/- 1.58 num::int_sqrt::non_zero_usize::isqrt 136005.21/iter +/- 2020.35 num::int_sqrt::u128::isqrt 439014.55/iter +/- 3920.45 num::int_sqrt::u16::isqrt 8575.08/iter +/- 148.06 num::int_sqrt::u32::isqrt 63008.89/iter +/- 803.67 num::int_sqrt::u64::isqrt 125088.09/iter +/- 879.29 num::int_sqrt::u8::isqrt 230.18/iter +/- 2.04 num::int_sqrt::usize::isqrt 125237.51/iter +/- 4747.83 ### After benchmarks: num::int_sqrt::i128::isqrt 105184.89/iter +/- 1171.38 num::int_sqrt::i16::isqrt 1910.26/iter +/- 78.50 num::int_sqrt::i32::isqrt 34260.34/iter +/- 960.84 num::int_sqrt::i64::isqrt 45939.19/iter +/- 2525.65 num::int_sqrt::i8::isqrt 22.87/iter +/- 0.45 num::int_sqrt::isize::isqrt 45884.17/iter +/- 595.49 num::int_sqrt::non_zero_u128::isqrt 106344.27/iter +/- 780.99 num::int_sqrt::non_zero_u16::isqrt 2790.19/iter +/- 53.43 num::int_sqrt::non_zero_u32::isqrt 33613.99/iter +/- 362.96 num::int_sqrt::non_zero_u64::isqrt 46235.42/iter +/- 429.69 num::int_sqrt::non_zero_u8::isqrt 31.78/iter +/- 0.75 num::int_sqrt::non_zero_usize::isqrt 46208.75/iter +/- 375.27 num::int_sqrt::u128::isqrt 106385.94/iter +/- 1649.95 num::int_sqrt::u16::isqrt 2747.69/iter +/- 28.72 num::int_sqrt::u32::isqrt 33627.09/iter +/- 475.68 num::int_sqrt::u64::isqrt 46182.29/iter +/- 311.16 num::int_sqrt::u8::isqrt 33.10/iter +/- 0.30 num::int_sqrt::usize::isqrt 46165.00/iter +/- 388.41 </details> Tracking Issue for {u8,i8,...}::isqrt #116226 try-job: test-various
This commit is contained in:
commit
4b08b2e400
@ -8,6 +8,7 @@
|
||||
#![feature(iter_array_chunks)]
|
||||
#![feature(iter_next_chunk)]
|
||||
#![feature(iter_advance_by)]
|
||||
#![feature(isqrt)]
|
||||
|
||||
extern crate test;
|
||||
|
||||
|
62
library/core/benches/num/int_sqrt/mod.rs
Normal file
62
library/core/benches/num/int_sqrt/mod.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use rand::Rng;
|
||||
use test::{black_box, Bencher};
|
||||
|
||||
macro_rules! int_sqrt_bench {
|
||||
($t:ty, $predictable:ident, $random:ident, $random_small:ident, $random_uniform:ident) => {
|
||||
#[bench]
|
||||
fn $predictable(bench: &mut Bencher) {
|
||||
bench.iter(|| {
|
||||
for n in 0..(<$t>::BITS / 8) {
|
||||
for i in 1..=(100 as $t) {
|
||||
let x = black_box(i << (n * 8));
|
||||
black_box(x.isqrt());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn $random(bench: &mut Bencher) {
|
||||
let mut rng = crate::bench_rng();
|
||||
/* Exponentially distributed random numbers from the whole range of the type. */
|
||||
let numbers: Vec<$t> =
|
||||
(0..256).map(|_| rng.gen::<$t>() >> rng.gen_range(0..<$t>::BITS)).collect();
|
||||
bench.iter(|| {
|
||||
for x in &numbers {
|
||||
black_box(black_box(x).isqrt());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn $random_small(bench: &mut Bencher) {
|
||||
let mut rng = crate::bench_rng();
|
||||
/* Exponentially distributed random numbers from the range 0..256. */
|
||||
let numbers: Vec<$t> =
|
||||
(0..256).map(|_| (rng.gen::<u8>() >> rng.gen_range(0..u8::BITS)) as $t).collect();
|
||||
bench.iter(|| {
|
||||
for x in &numbers {
|
||||
black_box(black_box(x).isqrt());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn $random_uniform(bench: &mut Bencher) {
|
||||
let mut rng = crate::bench_rng();
|
||||
/* Exponentially distributed random numbers from the whole range of the type. */
|
||||
let numbers: Vec<$t> = (0..256).map(|_| rng.gen::<$t>()).collect();
|
||||
bench.iter(|| {
|
||||
for x in &numbers {
|
||||
black_box(black_box(x).isqrt());
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
int_sqrt_bench! {u8, u8_sqrt_predictable, u8_sqrt_random, u8_sqrt_random_small, u8_sqrt_uniform}
|
||||
int_sqrt_bench! {u16, u16_sqrt_predictable, u16_sqrt_random, u16_sqrt_random_small, u16_sqrt_uniform}
|
||||
int_sqrt_bench! {u32, u32_sqrt_predictable, u32_sqrt_random, u32_sqrt_random_small, u32_sqrt_uniform}
|
||||
int_sqrt_bench! {u64, u64_sqrt_predictable, u64_sqrt_random, u64_sqrt_random_small, u64_sqrt_uniform}
|
||||
int_sqrt_bench! {u128, u128_sqrt_predictable, u128_sqrt_random, u128_sqrt_random_small, u128_sqrt_uniform}
|
@ -2,6 +2,7 @@
|
||||
mod flt2dec;
|
||||
mod int_log;
|
||||
mod int_pow;
|
||||
mod int_sqrt;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
|
@ -1641,7 +1641,33 @@ pub const fn checked_isqrt(self) -> Option<Self> {
|
||||
if self < 0 {
|
||||
None
|
||||
} else {
|
||||
Some((self as $UnsignedT).isqrt() as Self)
|
||||
// SAFETY: Input is nonnegative in this `else` branch.
|
||||
let result = unsafe {
|
||||
crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT
|
||||
};
|
||||
|
||||
// Inform the optimizer what the range of outputs is. If
|
||||
// testing `core` crashes with no panic message and a
|
||||
// `num::int_sqrt::i*` test failed, it's because your edits
|
||||
// caused these assertions to become false.
|
||||
//
|
||||
// SAFETY: Integer square root is a monotonically nondecreasing
|
||||
// function, which means that increasing the input will never
|
||||
// cause the output to decrease. Thus, since the input for
|
||||
// nonnegative signed integers is bounded by
|
||||
// `[0, <$ActualT>::MAX]`, sqrt(n) will be bounded by
|
||||
// `[sqrt(0), sqrt(<$ActualT>::MAX)]`.
|
||||
unsafe {
|
||||
// SAFETY: `<$ActualT>::MAX` is nonnegative.
|
||||
const MAX_RESULT: $SelfT = unsafe {
|
||||
crate::num::int_sqrt::$ActualT(<$ActualT>::MAX) as $SelfT
|
||||
};
|
||||
|
||||
crate::hint::assert_unchecked(result >= 0);
|
||||
crate::hint::assert_unchecked(result <= MAX_RESULT);
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
@ -2862,15 +2888,11 @@ pub const fn pow(self, mut exp: u32) -> Self {
|
||||
#[must_use = "this returns the result of the operation, \
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
#[track_caller]
|
||||
pub const fn isqrt(self) -> Self {
|
||||
// I would like to implement it as
|
||||
// ```
|
||||
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
|
||||
// ```
|
||||
// but `expect` is not yet stable as a `const fn`.
|
||||
match self.checked_isqrt() {
|
||||
Some(sqrt) => sqrt,
|
||||
None => panic!("argument of integer square root must be non-negative"),
|
||||
None => crate::num::int_sqrt::panic_for_negative_argument(),
|
||||
}
|
||||
}
|
||||
|
||||
|
316
library/core/src/num/int_sqrt.rs
Normal file
316
library/core/src/num/int_sqrt.rs
Normal file
@ -0,0 +1,316 @@
|
||||
//! These functions use the [Karatsuba square root algorithm][1] to compute the
|
||||
//! [integer square root](https://en.wikipedia.org/wiki/Integer_square_root)
|
||||
//! for the primitive integer types.
|
||||
//!
|
||||
//! The signed integer functions can only handle **nonnegative** inputs, so
|
||||
//! that must be checked before calling those.
|
||||
//!
|
||||
//! [1]: <https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf>
|
||||
//! "Paul Zimmermann. Karatsuba Square Root. \[Research Report\] RR-3805,
|
||||
//! INRIA. 1999, pp.8. (inria-00072854)"
|
||||
|
||||
/// This array stores the [integer square roots](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) and remainders of each
|
||||
/// [`u8`](prim@u8) value. For example, `U8_ISQRT_WITH_REMAINDER[17]` will be
|
||||
/// `(4, 1)` because the integer square root of 17 is 4 and because 17 is 1
|
||||
/// higher than 4 squared.
|
||||
const U8_ISQRT_WITH_REMAINDER: [(u8, u8); 256] = {
|
||||
let mut result = [(0, 0); 256];
|
||||
|
||||
let mut n: usize = 0;
|
||||
let mut isqrt_n: usize = 0;
|
||||
while n < result.len() {
|
||||
result[n] = (isqrt_n as u8, (n - isqrt_n.pow(2)) as u8);
|
||||
|
||||
n += 1;
|
||||
if n == (isqrt_n + 1).pow(2) {
|
||||
isqrt_n += 1;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
/// Returns the [integer square root](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) of any [`u8`](prim@u8)
|
||||
/// input.
|
||||
#[must_use = "this returns the result of the operation, \
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
pub const fn u8(n: u8) -> u8 {
|
||||
U8_ISQRT_WITH_REMAINDER[n as usize].0
|
||||
}
|
||||
|
||||
/// Generates an `i*` function that returns the [integer square root](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) of any **nonnegative**
|
||||
/// input of a specific signed integer type.
|
||||
macro_rules! signed_fn {
|
||||
($SignedT:ident, $UnsignedT:ident) => {
|
||||
/// Returns the [integer square root](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) of any
|
||||
/// **nonnegative**
|
||||
#[doc = concat!("[`", stringify!($SignedT), "`](prim@", stringify!($SignedT), ")")]
|
||||
/// input.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This results in undefined behavior when the input is negative.
|
||||
#[must_use = "this returns the result of the operation, \
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
pub const unsafe fn $SignedT(n: $SignedT) -> $SignedT {
|
||||
debug_assert!(n >= 0, "Negative input inside `isqrt`.");
|
||||
$UnsignedT(n as $UnsignedT) as $SignedT
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
signed_fn!(i8, u8);
|
||||
signed_fn!(i16, u16);
|
||||
signed_fn!(i32, u32);
|
||||
signed_fn!(i64, u64);
|
||||
signed_fn!(i128, u128);
|
||||
|
||||
/// Generates a `u*` function that returns the [integer square root](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) of any input of
|
||||
/// a specific unsigned integer type.
|
||||
macro_rules! unsigned_fn {
|
||||
($UnsignedT:ident, $HalfBitsT:ident, $stages:ident) => {
|
||||
/// Returns the [integer square root](
|
||||
/// https://en.wikipedia.org/wiki/Integer_square_root) of any
|
||||
#[doc = concat!("[`", stringify!($UnsignedT), "`](prim@", stringify!($UnsignedT), ")")]
|
||||
/// input.
|
||||
#[must_use = "this returns the result of the operation, \
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
pub const fn $UnsignedT(mut n: $UnsignedT) -> $UnsignedT {
|
||||
if n <= <$HalfBitsT>::MAX as $UnsignedT {
|
||||
$HalfBitsT(n as $HalfBitsT) as $UnsignedT
|
||||
} else {
|
||||
// The normalization shift satisfies the Karatsuba square root
|
||||
// algorithm precondition "a₃ ≥ b/4" where a₃ is the most
|
||||
// significant quarter of `n`'s bits and b is the number of
|
||||
// values that can be represented by that quarter of the bits.
|
||||
//
|
||||
// b/4 would then be all 0s except the second most significant
|
||||
// bit (010...0) in binary. Since a₃ must be at least b/4, a₃'s
|
||||
// most significant bit or its neighbor must be a 1. Since a₃'s
|
||||
// most significant bits are `n`'s most significant bits, the
|
||||
// same applies to `n`.
|
||||
//
|
||||
// The reason to shift by an even number of bits is because an
|
||||
// even number of bits produces the square root shifted to the
|
||||
// left by half of the normalization shift:
|
||||
//
|
||||
// sqrt(n << (2 * p))
|
||||
// sqrt(2.pow(2 * p) * n)
|
||||
// sqrt(2.pow(2 * p)) * sqrt(n)
|
||||
// 2.pow(p) * sqrt(n)
|
||||
// sqrt(n) << p
|
||||
//
|
||||
// Shifting by an odd number of bits leaves an ugly sqrt(2)
|
||||
// multiplied in:
|
||||
//
|
||||
// sqrt(n << (2 * p + 1))
|
||||
// sqrt(2.pow(2 * p + 1) * n)
|
||||
// sqrt(2 * 2.pow(2 * p) * n)
|
||||
// sqrt(2) * sqrt(2.pow(2 * p)) * sqrt(n)
|
||||
// sqrt(2) * 2.pow(p) * sqrt(n)
|
||||
// sqrt(2) * (sqrt(n) << p)
|
||||
const EVEN_MAKING_BITMASK: u32 = !1;
|
||||
let normalization_shift = n.leading_zeros() & EVEN_MAKING_BITMASK;
|
||||
n <<= normalization_shift;
|
||||
|
||||
let s = $stages(n);
|
||||
|
||||
let denormalization_shift = normalization_shift >> 1;
|
||||
s >> denormalization_shift
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Generates the first stage of the computation after normalization.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `$n` must be nonzero.
|
||||
macro_rules! first_stage {
|
||||
($original_bits:literal, $n:ident) => {{
|
||||
debug_assert!($n != 0, "`$n` is zero in `first_stage!`.");
|
||||
|
||||
const N_SHIFT: u32 = $original_bits - 8;
|
||||
let n = $n >> N_SHIFT;
|
||||
|
||||
let (s, r) = U8_ISQRT_WITH_REMAINDER[n as usize];
|
||||
|
||||
// Inform the optimizer that `s` is nonzero. This will allow it to
|
||||
// avoid generating code to handle division-by-zero panics in the next
|
||||
// stage.
|
||||
//
|
||||
// SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
|
||||
// macro recurses instead of continuing to this point, so the original
|
||||
// `$n` wasn't a 0 if we've reached here.
|
||||
//
|
||||
// Then the `unsigned_fn` macro normalizes `$n` so that at least one of
|
||||
// its two most-significant bits is a 1.
|
||||
//
|
||||
// Then this stage puts the eight most-significant bits of `$n` into
|
||||
// `n`. This means that `n` here has at least one 1 bit in its two
|
||||
// most-significant bits, making `n` nonzero.
|
||||
//
|
||||
// `U8_ISQRT_WITH_REMAINDER[n as usize]` will give a nonzero `s` when
|
||||
// given a nonzero `n`.
|
||||
unsafe { crate::hint::assert_unchecked(s != 0) };
|
||||
(s, r)
|
||||
}};
|
||||
}
|
||||
|
||||
/// Generates a middle stage of the computation.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `$s` must be nonzero.
|
||||
macro_rules! middle_stage {
|
||||
($original_bits:literal, $ty:ty, $n:ident, $s:ident, $r:ident) => {{
|
||||
debug_assert!($s != 0, "`$s` is zero in `middle_stage!`.");
|
||||
|
||||
const N_SHIFT: u32 = $original_bits - <$ty>::BITS;
|
||||
let n = ($n >> N_SHIFT) as $ty;
|
||||
|
||||
const HALF_BITS: u32 = <$ty>::BITS >> 1;
|
||||
const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
|
||||
const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
|
||||
const LOWEST_QUARTER_1_BITS: $ty = (1 << QUARTER_BITS) - 1;
|
||||
|
||||
let lo = n & LOWER_HALF_1_BITS;
|
||||
let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
|
||||
let denominator = ($s as $ty) << 1;
|
||||
let q = numerator / denominator;
|
||||
let u = numerator % denominator;
|
||||
|
||||
let mut s = ($s << QUARTER_BITS) as $ty + q;
|
||||
let (mut r, overflow) =
|
||||
((u << QUARTER_BITS) | (lo & LOWEST_QUARTER_1_BITS)).overflowing_sub(q * q);
|
||||
if overflow {
|
||||
r = r.wrapping_add(2 * s - 1);
|
||||
s -= 1;
|
||||
}
|
||||
|
||||
// Inform the optimizer that `s` is nonzero. This will allow it to
|
||||
// avoid generating code to handle division-by-zero panics in the next
|
||||
// stage.
|
||||
//
|
||||
// SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
|
||||
// macro recurses instead of continuing to this point, so the original
|
||||
// `$n` wasn't a 0 if we've reached here.
|
||||
//
|
||||
// Then the `unsigned_fn` macro normalizes `$n` so that at least one of
|
||||
// its two most-significant bits is a 1.
|
||||
//
|
||||
// Then these stages take as many of the most-significant bits of `$n`
|
||||
// as will fit in this stage's type. For example, the stage that
|
||||
// handles `u32` deals with the 32 most-significant bits of `$n`. This
|
||||
// means that each stage has at least one 1 bit in `n`'s two
|
||||
// most-significant bits, making `n` nonzero.
|
||||
//
|
||||
// Then this stage will produce the correct integer square root for
|
||||
// that `n` value. Since `n` is nonzero, `s` will also be nonzero.
|
||||
unsafe { crate::hint::assert_unchecked(s != 0) };
|
||||
(s, r)
|
||||
}};
|
||||
}
|
||||
|
||||
/// Generates the last stage of the computation before denormalization.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `$s` must be nonzero.
|
||||
macro_rules! last_stage {
|
||||
($ty:ty, $n:ident, $s:ident, $r:ident) => {{
|
||||
debug_assert!($s != 0, "`$s` is zero in `last_stage!`.");
|
||||
|
||||
const HALF_BITS: u32 = <$ty>::BITS >> 1;
|
||||
const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
|
||||
const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
|
||||
|
||||
let lo = $n & LOWER_HALF_1_BITS;
|
||||
let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
|
||||
let denominator = ($s as $ty) << 1;
|
||||
|
||||
let q = numerator / denominator;
|
||||
let mut s = ($s << QUARTER_BITS) as $ty + q;
|
||||
let (s_squared, overflow) = s.overflowing_mul(s);
|
||||
if overflow || s_squared > $n {
|
||||
s -= 1;
|
||||
}
|
||||
s
|
||||
}};
|
||||
}
|
||||
|
||||
/// Takes the normalized [`u16`](prim@u16) input and gets its normalized
|
||||
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `n` must be nonzero.
|
||||
#[inline]
|
||||
const fn u16_stages(n: u16) -> u16 {
|
||||
let (s, r) = first_stage!(16, n);
|
||||
last_stage!(u16, n, s, r)
|
||||
}
|
||||
|
||||
/// Takes the normalized [`u32`](prim@u32) input and gets its normalized
|
||||
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `n` must be nonzero.
|
||||
#[inline]
|
||||
const fn u32_stages(n: u32) -> u32 {
|
||||
let (s, r) = first_stage!(32, n);
|
||||
let (s, r) = middle_stage!(32, u16, n, s, r);
|
||||
last_stage!(u32, n, s, r)
|
||||
}
|
||||
|
||||
/// Takes the normalized [`u64`](prim@u64) input and gets its normalized
|
||||
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `n` must be nonzero.
|
||||
#[inline]
|
||||
const fn u64_stages(n: u64) -> u64 {
|
||||
let (s, r) = first_stage!(64, n);
|
||||
let (s, r) = middle_stage!(64, u16, n, s, r);
|
||||
let (s, r) = middle_stage!(64, u32, n, s, r);
|
||||
last_stage!(u64, n, s, r)
|
||||
}
|
||||
|
||||
/// Takes the normalized [`u128`](prim@u128) input and gets its normalized
|
||||
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `n` must be nonzero.
|
||||
#[inline]
|
||||
const fn u128_stages(n: u128) -> u128 {
|
||||
let (s, r) = first_stage!(128, n);
|
||||
let (s, r) = middle_stage!(128, u16, n, s, r);
|
||||
let (s, r) = middle_stage!(128, u32, n, s, r);
|
||||
let (s, r) = middle_stage!(128, u64, n, s, r);
|
||||
last_stage!(u128, n, s, r)
|
||||
}
|
||||
|
||||
unsigned_fn!(u16, u8, u16_stages);
|
||||
unsigned_fn!(u32, u16, u32_stages);
|
||||
unsigned_fn!(u64, u32, u64_stages);
|
||||
unsigned_fn!(u128, u64, u128_stages);
|
||||
|
||||
/// Instantiate this panic logic once, rather than for all the isqrt methods
|
||||
/// on every single primitive type.
|
||||
#[cold]
|
||||
#[track_caller]
|
||||
pub const fn panic_for_negative_argument() -> ! {
|
||||
panic!("argument of integer square root cannot be negative")
|
||||
}
|
@ -41,6 +41,7 @@ macro_rules! unlikely {
|
||||
|
||||
mod error;
|
||||
mod int_log10;
|
||||
mod int_sqrt;
|
||||
mod nonzero;
|
||||
mod overflow_panic;
|
||||
mod saturating;
|
||||
|
@ -7,7 +7,7 @@
|
||||
use crate::ops::{BitOr, BitOrAssign, Div, DivAssign, Neg, Rem, RemAssign};
|
||||
use crate::panic::{RefUnwindSafe, UnwindSafe};
|
||||
use crate::str::FromStr;
|
||||
use crate::{fmt, hint, intrinsics, ptr, ub_checks};
|
||||
use crate::{fmt, intrinsics, ptr, ub_checks};
|
||||
|
||||
/// A marker trait for primitive types which can be zero.
|
||||
///
|
||||
@ -1545,31 +1545,14 @@ pub const fn is_power_of_two(self) -> bool {
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
pub const fn isqrt(self) -> Self {
|
||||
// The algorithm is based on the one presented in
|
||||
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
|
||||
// which cites as source the following C code:
|
||||
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
|
||||
let result = self.get().isqrt();
|
||||
|
||||
let mut op = self.get();
|
||||
let mut res = 0;
|
||||
let mut one = 1 << (self.ilog2() & !1);
|
||||
|
||||
while one != 0 {
|
||||
if op >= res + one {
|
||||
op -= res + one;
|
||||
res = (res >> 1) + one;
|
||||
} else {
|
||||
res >>= 1;
|
||||
}
|
||||
one >>= 2;
|
||||
}
|
||||
|
||||
// SAFETY: The result fits in an integer with half as many bits.
|
||||
// Inform the optimizer about it.
|
||||
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };
|
||||
|
||||
// SAFETY: The square root of an integer >= 1 is always >= 1.
|
||||
unsafe { Self::new_unchecked(res) }
|
||||
// SAFETY: Integer square root is a monotonically nondecreasing
|
||||
// function, which means that increasing the input will never cause
|
||||
// the output to decrease. Thus, since the input for nonzero
|
||||
// unsigned integers has a lower bound of 1, the lower bound of the
|
||||
// results will be sqrt(1), which is 1, so a result can't be zero.
|
||||
unsafe { Self::new_unchecked(result) }
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2762,10 +2762,24 @@ pub const fn pow(self, mut exp: u32) -> Self {
|
||||
without modifying the original"]
|
||||
#[inline]
|
||||
pub const fn isqrt(self) -> Self {
|
||||
match NonZero::new(self) {
|
||||
Some(x) => x.isqrt().get(),
|
||||
None => 0,
|
||||
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
|
||||
|
||||
// Inform the optimizer what the range of outputs is. If testing
|
||||
// `core` crashes with no panic message and a `num::int_sqrt::u*`
|
||||
// test failed, it's because your edits caused these assertions or
|
||||
// the assertions in `fn isqrt` of `nonzero.rs` to become false.
|
||||
//
|
||||
// SAFETY: Integer square root is a monotonically nondecreasing
|
||||
// function, which means that increasing the input will never
|
||||
// cause the output to decrease. Thus, since the input for unsigned
|
||||
// integers is bounded by `[0, <$ActualT>::MAX]`, sqrt(n) will be
|
||||
// bounded by `[sqrt(0), sqrt(<$ActualT>::MAX)]`.
|
||||
unsafe {
|
||||
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT(<$ActualT>::MAX) as $SelfT;
|
||||
crate::hint::assert_unchecked(result <= MAX_RESULT);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Performs Euclidean division.
|
||||
|
@ -288,38 +288,6 @@ fn test_pow() {
|
||||
assert_eq!(r.saturating_pow(0), 1 as $T);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_isqrt() {
|
||||
assert_eq!($T::MIN.checked_isqrt(), None);
|
||||
assert_eq!((-1 as $T).checked_isqrt(), None);
|
||||
assert_eq!((0 as $T).isqrt(), 0 as $T);
|
||||
assert_eq!((1 as $T).isqrt(), 1 as $T);
|
||||
assert_eq!((2 as $T).isqrt(), 1 as $T);
|
||||
assert_eq!((99 as $T).isqrt(), 9 as $T);
|
||||
assert_eq!((100 as $T).isqrt(), 10 as $T);
|
||||
}
|
||||
|
||||
#[cfg(not(miri))] // Miri is too slow
|
||||
#[test]
|
||||
fn test_lots_of_isqrt() {
|
||||
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
|
||||
for n in 0..=n_max {
|
||||
let isqrt: $T = n.isqrt();
|
||||
|
||||
assert!(isqrt.pow(2) <= n);
|
||||
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
|
||||
assert!(overflow || square > n);
|
||||
}
|
||||
|
||||
for n in ($T::MAX - 127)..=$T::MAX {
|
||||
let isqrt: $T = n.isqrt();
|
||||
|
||||
assert!(isqrt.pow(2) <= n);
|
||||
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
|
||||
assert!(overflow || square > n);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div_floor() {
|
||||
let a: $T = 8;
|
||||
|
248
library/core/tests/num/int_sqrt.rs
Normal file
248
library/core/tests/num/int_sqrt.rs
Normal file
@ -0,0 +1,248 @@
|
||||
macro_rules! tests {
|
||||
($isqrt_consistency_check_fn_macro:ident : $($T:ident)+) => {
|
||||
$(
|
||||
mod $T {
|
||||
$isqrt_consistency_check_fn_macro!($T);
|
||||
|
||||
// Check that the following produce the correct values from
|
||||
// `isqrt`:
|
||||
//
|
||||
// * the first and last 128 nonnegative values
|
||||
// * powers of two, minus one
|
||||
// * powers of two
|
||||
//
|
||||
// For signed types, check that `checked_isqrt` and `isqrt`
|
||||
// either produce the same numeric value or respectively
|
||||
// produce `None` and a panic. Make sure to do a consistency
|
||||
// check for `<$T>::MIN` as well, as no nonnegative values
|
||||
// negate to it.
|
||||
//
|
||||
// For unsigned types check that `isqrt` produces the same
|
||||
// numeric value for `$T` and `NonZero<$T>`.
|
||||
#[test]
|
||||
fn isqrt() {
|
||||
isqrt_consistency_check(<$T>::MIN);
|
||||
|
||||
for n in (0..=127)
|
||||
.chain(<$T>::MAX - 127..=<$T>::MAX)
|
||||
.chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1))
|
||||
.chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent))
|
||||
{
|
||||
isqrt_consistency_check(n);
|
||||
|
||||
let isqrt_n = n.isqrt();
|
||||
assert!(
|
||||
isqrt_n
|
||||
.checked_mul(isqrt_n)
|
||||
.map(|isqrt_n_squared| isqrt_n_squared <= n)
|
||||
.unwrap_or(false),
|
||||
"`{n}.isqrt()` should be lower than {isqrt_n}."
|
||||
);
|
||||
assert!(
|
||||
(isqrt_n + 1)
|
||||
.checked_mul(isqrt_n + 1)
|
||||
.map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared)
|
||||
.unwrap_or(true),
|
||||
"`{n}.isqrt()` should be higher than {isqrt_n})."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check the square roots of:
|
||||
//
|
||||
// * the first 1,024 perfect squares
|
||||
// * halfway between each of the first 1,024 perfect squares
|
||||
// and the next perfect square
|
||||
// * the next perfect square after the each of the first 1,024
|
||||
// perfect squares, minus one
|
||||
// * the last 1,024 perfect squares
|
||||
// * the last 1,024 perfect squares, minus one
|
||||
// * halfway between each of the last 1,024 perfect squares
|
||||
// and the previous perfect square
|
||||
#[test]
|
||||
// Skip this test on Miri, as it takes too long to run.
|
||||
#[cfg(not(miri))]
|
||||
fn isqrt_extended() {
|
||||
// The correct value is worked out by using the fact that
|
||||
// the nth nonzero perfect square is the sum of the first n
|
||||
// odd numbers:
|
||||
//
|
||||
// 1 = 1
|
||||
// 4 = 1 + 3
|
||||
// 9 = 1 + 3 + 5
|
||||
// 16 = 1 + 3 + 5 + 7
|
||||
//
|
||||
// Note also that the last odd number added in is two times
|
||||
// the square root of the previous perfect square, plus
|
||||
// one:
|
||||
//
|
||||
// 1 = 2*0 + 1
|
||||
// 3 = 2*1 + 1
|
||||
// 5 = 2*2 + 1
|
||||
// 7 = 2*3 + 1
|
||||
//
|
||||
// That means we can add the square root of this perfect
|
||||
// square once to get about halfway to the next perfect
|
||||
// square, then we can add the square root of this perfect
|
||||
// square again to get to the next perfect square, minus
|
||||
// one, then we can add one to get to the next perfect
|
||||
// square.
|
||||
//
|
||||
// This allows us to, for each of the first 1,024 perfect
|
||||
// squares, test that the square roots of the following are
|
||||
// all correct and equal to each other:
|
||||
//
|
||||
// * the current perfect square
|
||||
// * about halfway to the next perfect square
|
||||
// * the next perfect square, minus one
|
||||
let mut n: $T = 0;
|
||||
for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T {
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n,
|
||||
"`{sqrt_n}.pow(2).isqrt()` should be {sqrt_n}."
|
||||
);
|
||||
|
||||
n += sqrt_n;
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n,
|
||||
"{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
|
||||
sqrt_n + 1
|
||||
);
|
||||
|
||||
n += sqrt_n;
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n,
|
||||
"`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
|
||||
sqrt_n + 1
|
||||
);
|
||||
|
||||
n += 1;
|
||||
}
|
||||
|
||||
// Similarly, for each of the last 1,024 perfect squares,
|
||||
// check:
|
||||
//
|
||||
// * the current perfect square
|
||||
// * the current perfect square, minus one
|
||||
// * about halfway to the previous perfect square
|
||||
//
|
||||
// `MAX`'s `isqrt` return value is verified in the `isqrt`
|
||||
// test function above.
|
||||
let maximum_sqrt = <$T>::MAX.isqrt();
|
||||
let mut n = maximum_sqrt * maximum_sqrt;
|
||||
|
||||
for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() {
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n + 1,
|
||||
"`{0}.pow(2).isqrt()` should be {0}.",
|
||||
sqrt_n + 1
|
||||
);
|
||||
|
||||
n -= 1;
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n,
|
||||
"`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
|
||||
sqrt_n + 1
|
||||
);
|
||||
|
||||
n -= sqrt_n;
|
||||
isqrt_consistency_check(n);
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
sqrt_n,
|
||||
"{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
|
||||
sqrt_n + 1
|
||||
);
|
||||
|
||||
n -= sqrt_n;
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! signed_check {
|
||||
($T:ident) => {
|
||||
/// This takes an input and, if it's nonnegative or
|
||||
#[doc = concat!("`", stringify!($T), "::MIN`,")]
|
||||
/// checks that `isqrt` and `checked_isqrt` produce equivalent results
|
||||
/// for that input and for the negative of that input.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This cannot check that negative inputs to `isqrt` cause panics if
|
||||
/// panics abort instead of unwind.
|
||||
fn isqrt_consistency_check(n: $T) {
|
||||
// `<$T>::MIN` will be negative, so ignore it in this nonnegative
|
||||
// section.
|
||||
if n >= 0 {
|
||||
assert_eq!(
|
||||
Some(n.isqrt()),
|
||||
n.checked_isqrt(),
|
||||
"`{n}.checked_isqrt()` should match `Some({n}.isqrt())`.",
|
||||
);
|
||||
}
|
||||
|
||||
// `wrapping_neg` so that `<$T>::MIN` will negate to itself rather
|
||||
// than panicking.
|
||||
let negative_n = n.wrapping_neg();
|
||||
|
||||
// Zero negated will still be nonnegative, so ignore it in this
|
||||
// negative section.
|
||||
if negative_n < 0 {
|
||||
assert_eq!(
|
||||
negative_n.checked_isqrt(),
|
||||
None,
|
||||
"`({negative_n}).checked_isqrt()` should be `None`, as {negative_n} is negative.",
|
||||
);
|
||||
|
||||
// `catch_unwind` only works when panics unwind rather than abort.
|
||||
#[cfg(panic = "unwind")]
|
||||
{
|
||||
std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err(
|
||||
&format!("`({negative_n}).isqrt()` should have panicked, as {negative_n} is negative.")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! unsigned_check {
|
||||
($T:ident) => {
|
||||
/// This takes an input and, if it's nonzero, checks that `isqrt`
|
||||
/// produces the same numeric value for both
|
||||
#[doc = concat!("`", stringify!($T), "` and ")]
|
||||
#[doc = concat!("`NonZero<", stringify!($T), ">`.")]
|
||||
fn isqrt_consistency_check(n: $T) {
|
||||
// Zero cannot be turned into a `NonZero` value, so ignore it in
|
||||
// this nonzero section.
|
||||
if n > 0 {
|
||||
assert_eq!(
|
||||
n.isqrt(),
|
||||
core::num::NonZero::<$T>::new(n)
|
||||
.expect(
|
||||
"Was not able to create a new `NonZero` value from a nonzero number."
|
||||
)
|
||||
.isqrt()
|
||||
.get(),
|
||||
"`{n}.isqrt` should match `NonZero`'s `{n}.isqrt().get()`.",
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
tests!(signed_check: i8 i16 i32 i64 i128);
|
||||
tests!(unsigned_check: u8 u16 u32 u64 u128);
|
@ -27,6 +27,7 @@
|
||||
mod dec2flt;
|
||||
mod flt2dec;
|
||||
mod int_log;
|
||||
mod int_sqrt;
|
||||
mod ops;
|
||||
mod wrapping;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user