diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs index 8b04bafcda5..be734a9ba52 100644 --- a/library/core/src/lib.rs +++ b/library/core/src/lib.rs @@ -178,6 +178,7 @@ #![feature(ip)] #![feature(ip_bits)] #![feature(is_ascii_octdigit)] +#![feature(isqrt)] #![feature(maybe_uninit_uninit_array)] #![feature(ptr_alignment_type)] #![feature(ptr_metadata)] diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs index 1f43520e1b3..3cbb55af3bc 100644 --- a/library/core/src/num/int_macros.rs +++ b/library/core/src/num/int_macros.rs @@ -898,6 +898,30 @@ pub const fn checked_pow(self, mut exp: u32) -> Option { acc.checked_mul(base) } + /// Returns the square root of the number, rounded down. + /// + /// Returns `None` if `self` is negative. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// #![feature(isqrt)] + #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".checked_isqrt(), Some(3));")] + /// ``` + #[unstable(feature = "isqrt", issue = "116226")] + #[rustc_const_unstable(feature = "isqrt", issue = "116226")] + #[must_use = "this returns the result of the operation, \ + without modifying the original"] + #[inline] + pub const fn checked_isqrt(self) -> Option { + if self < 0 { + None + } else { + Some((self as $UnsignedT).isqrt() as Self) + } + } + /// Saturating integer addition. Computes `self + rhs`, saturating at the numeric /// bounds instead of overflowing. /// @@ -2061,6 +2085,36 @@ pub const fn pow(self, mut exp: u32) -> Self { acc * base } + /// Returns the square root of the number, rounded down. + /// + /// # Panics + /// + /// This function will panic if `self` is negative. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// #![feature(isqrt)] + #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")] + /// ``` + #[unstable(feature = "isqrt", issue = "116226")] + #[rustc_const_unstable(feature = "isqrt", issue = "116226")] + #[must_use = "this returns the result of the operation, \ + without modifying the original"] + #[inline] + pub const fn isqrt(self) -> Self { + // I would like to implement it as + // ``` + // self.checked_isqrt().expect("argument of integer square root must be non-negative") + // ``` + // but `expect` is not yet stable as a `const fn`. + match self.checked_isqrt() { + Some(sqrt) => sqrt, + None => panic!("argument of integer square root must be non-negative"), + } + } + /// Calculates the quotient of Euclidean division of `self` by `rhs`. /// /// This computes the integer `q` such that `self = q * rhs + r`, with diff --git a/library/core/src/num/uint_macros.rs b/library/core/src/num/uint_macros.rs index 7cbef9e7793..a9c5312a1c0 100644 --- a/library/core/src/num/uint_macros.rs +++ b/library/core/src/num/uint_macros.rs @@ -1995,6 +1995,54 @@ pub const fn pow(self, mut exp: u32) -> Self { acc * base } + /// Returns the square root of the number, rounded down. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// #![feature(isqrt)] + #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")] + /// ``` + #[unstable(feature = "isqrt", issue = "116226")] + #[rustc_const_unstable(feature = "isqrt", issue = "116226")] + #[must_use = "this returns the result of the operation, \ + without modifying the original"] + #[inline] + pub const fn isqrt(self) -> Self { + if self < 2 { + return self; + } + + // The algorithm is based on the one presented in + // + // which cites as source the following C code: + // . + + let mut op = self; + let mut res = 0; + let mut one = 1 << (self.ilog2() & !1); + + while one != 0 { + if op >= res + one { + op -= res + one; + res = (res >> 1) + one; + } else { + res >>= 1; + } + one >>= 2; + } + + // SAFETY: the result is positive and fits in an integer with half as many bits. + // Inform the optimizer about it. + unsafe { + intrinsics::assume(0 < res); + intrinsics::assume(res < 1 << (Self::BITS / 2)); + } + + res + } + /// Performs Euclidean division. /// /// Since, for the positive integers, all common diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs index 773f2b955d8..e4003a208bc 100644 --- a/library/core/tests/lib.rs +++ b/library/core/tests/lib.rs @@ -56,6 +56,7 @@ #![feature(min_specialization)] #![feature(numfmt)] #![feature(num_midpoint)] +#![feature(isqrt)] #![feature(step_trait)] #![feature(str_internals)] #![feature(std_internals)] diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs index 439bbe66997..165d9a29617 100644 --- a/library/core/tests/num/int_macros.rs +++ b/library/core/tests/num/int_macros.rs @@ -290,6 +290,38 @@ fn test_pow() { assert_eq!(r.saturating_pow(0), 1 as $T); } + #[test] + fn test_isqrt() { + assert_eq!($T::MIN.checked_isqrt(), None); + assert_eq!((-1 as $T).checked_isqrt(), None); + assert_eq!((0 as $T).isqrt(), 0 as $T); + assert_eq!((1 as $T).isqrt(), 1 as $T); + assert_eq!((2 as $T).isqrt(), 1 as $T); + assert_eq!((99 as $T).isqrt(), 9 as $T); + assert_eq!((100 as $T).isqrt(), 10 as $T); + } + + #[cfg(not(miri))] // Miri is too slow + #[test] + fn test_lots_of_isqrt() { + let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T; + for n in 0..=n_max { + let isqrt: $T = n.isqrt(); + + assert!(isqrt.pow(2) <= n); + let (square, overflow) = (isqrt + 1).overflowing_pow(2); + assert!(overflow || square > n); + } + + for n in ($T::MAX - 127)..=$T::MAX { + let isqrt: $T = n.isqrt(); + + assert!(isqrt.pow(2) <= n); + let (square, overflow) = (isqrt + 1).overflowing_pow(2); + assert!(overflow || square > n); + } + } + #[test] fn test_div_floor() { let a: $T = 8; diff --git a/library/core/tests/num/uint_macros.rs b/library/core/tests/num/uint_macros.rs index 7d6203db0b9..955440647eb 100644 --- a/library/core/tests/num/uint_macros.rs +++ b/library/core/tests/num/uint_macros.rs @@ -206,6 +206,35 @@ fn test_pow() { assert_eq!(r.saturating_pow(2), MAX); } + #[test] + fn test_isqrt() { + assert_eq!((0 as $T).isqrt(), 0 as $T); + assert_eq!((1 as $T).isqrt(), 1 as $T); + assert_eq!((2 as $T).isqrt(), 1 as $T); + assert_eq!((99 as $T).isqrt(), 9 as $T); + assert_eq!((100 as $T).isqrt(), 10 as $T); + assert_eq!($T::MAX.isqrt(), (1 << ($T::BITS / 2)) - 1); + } + + #[cfg(not(miri))] // Miri is too slow + #[test] + fn test_lots_of_isqrt() { + let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T; + for n in 0..=n_max { + let isqrt: $T = n.isqrt(); + + assert!(isqrt.pow(2) <= n); + assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n); + } + + for n in ($T::MAX - 255)..=$T::MAX { + let isqrt: $T = n.isqrt(); + + assert!(isqrt.pow(2) <= n); + assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n); + } + } + #[test] fn test_div_floor() { assert_eq!((8 as $T).div_floor(3), 2);