Implement SHA256 SIMD intrinsics on x86

It'd be useful to be able to verify code implementing SHA256 using SIMD,
since such code is a bit more complicated and at some points requires the
use of pointers. Until now `miri` didn't support the x86 SHA256 intrinsics;
this commit implements them.
Martin Habovstiak 2024-07-17 16:11:32 +02:00
parent b3736d687a
commit 728876ea98
3 changed files with 497 additions and 0 deletions

@@ -15,6 +15,7 @@
mod avx;
mod avx2;
mod bmi;
mod sha;
mod sse;
mod sse2;
mod sse3;
@@ -105,6 +106,11 @@ fn emulate_x86_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sha") => {
                return sha::EvalContextExt::emulate_x86_sha_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse.") => {
                return sse::EvalContextExt::emulate_x86_sse_intrinsic(
                    this, link_name, abi, args, dest,

@@ -0,0 +1,221 @@
//! Implements the SHA256 SIMD instructions for x86 targets.
//!
//! The functions that actually compute SHA256 were copied from [RustCrypto's sha256 module].
//!
//! [RustCrypto's sha256 module]: https://github.com/RustCrypto/hashes/blob/6be8466247e936c415d8aafb848697f39894a386/sha2/src/sha256/soft.rs
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
    fn emulate_x86_sha_intrinsic(
        &mut self,
        link_name: Symbol,
        abi: Abi,
        args: &[OpTy<'tcx>],
        dest: &MPlaceTy<'tcx>,
    ) -> InterpResult<'tcx, EmulateItemResult> {
        let this = self.eval_context_mut();
        this.expect_target_feature_for_intrinsic(link_name, "sha")?;
        // Prefix should have already been checked.
        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sha").unwrap();

        fn read<'c>(
            this: &mut MiriInterpCx<'c>,
            reg: &MPlaceTy<'c>,
        ) -> InterpResult<'c, [u32; 4]> {
            let mut res = [0; 4];
            // We reverse the order because x86 is little endian but the copied implementation
            // uses big endian.
            for (i, dst) in res.iter_mut().rev().enumerate() {
                let projected = &this.project_index(reg, i.try_into().unwrap())?;
                *dst = this.read_scalar(projected)?.to_u32()?;
            }
            Ok(res)
        }

        fn write<'c>(
            this: &mut MiriInterpCx<'c>,
            dest: &MPlaceTy<'c>,
            val: [u32; 4],
        ) -> InterpResult<'c, ()> {
            // We reverse the order because x86 is little endian but the copied implementation
            // uses big endian.
            for (i, part) in val.into_iter().rev().enumerate() {
                let projected = &this.project_index(dest, i.try_into().unwrap())?;
                this.write_scalar(Scalar::from_u32(part), projected)?;
            }
            Ok(())
        }

        match unprefixed_name {
            // Used to implement the _mm_sha256rnds2_epu32 function.
            "256rnds2" => {
                let [a, b, k] =
                    this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

                let (a_reg, a_len) = this.operand_to_simd(a)?;
                let (b_reg, b_len) = this.operand_to_simd(b)?;
                let (k_reg, k_len) = this.operand_to_simd(k)?;
                let (dest, dest_len) = this.mplace_to_simd(dest)?;

                assert_eq!(a_len, 4);
                assert_eq!(b_len, 4);
                assert_eq!(k_len, 4);
                assert_eq!(dest_len, 4);

                let a = read(this, &a_reg)?;
                let b = read(this, &b_reg)?;
                let k = read(this, &k_reg)?;

                let result = sha256_digest_round_x2(a, b, k);
                write(this, &dest, result)?;
            }
            // Used to implement the _mm_sha256msg1_epu32 function.
            "256msg1" => {
                let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

                let (a_reg, a_len) = this.operand_to_simd(a)?;
                let (b_reg, b_len) = this.operand_to_simd(b)?;
                let (dest, dest_len) = this.mplace_to_simd(dest)?;

                assert_eq!(a_len, 4);
                assert_eq!(b_len, 4);
                assert_eq!(dest_len, 4);

                let a = read(this, &a_reg)?;
                let b = read(this, &b_reg)?;

                let result = sha256msg1(a, b);
                write(this, &dest, result)?;
            }
            // Used to implement the _mm_sha256msg2_epu32 function.
            "256msg2" => {
                let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

                let (a_reg, a_len) = this.operand_to_simd(a)?;
                let (b_reg, b_len) = this.operand_to_simd(b)?;
                let (dest, dest_len) = this.mplace_to_simd(dest)?;

                assert_eq!(a_len, 4);
                assert_eq!(b_len, 4);
                assert_eq!(dest_len, 4);

                let a = read(this, &a_reg)?;
                let b = read(this, &b_reg)?;

                let result = sha256msg2(a, b);
                write(this, &dest, result)?;
            }
            _ => return Ok(EmulateItemResult::NotSupported),
        }
        Ok(EmulateItemResult::NeedsReturn)
    }
}
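
// For orientation (an illustrative sketch, not used by this module): the
// SHA-NI convention packs the eight SHA256 working variables a..h into two
// four-word registers, which is exactly how `sha256_digest_round_x2` below
// destructures its arguments. The helper name `pack_state` is hypothetical:
//
//     fn pack_state(s: [u32; 8]) -> ([u32; 4], [u32; 4]) {
//         let [a, b, c, d, e, f, g, h] = s;
//         // (abef, cdgh), in the word order used by the soft implementation.
//         ([a, b, e, f], [c, d, g, h])
//     }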
#[inline(always)]
fn shr(v: [u32; 4], o: u32) -> [u32; 4] {
    [v[0] >> o, v[1] >> o, v[2] >> o, v[3] >> o]
}

#[inline(always)]
fn shl(v: [u32; 4], o: u32) -> [u32; 4] {
    [v[0] << o, v[1] << o, v[2] << o, v[3] << o]
}

#[inline(always)]
fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    [a[0] | b[0], a[1] | b[1], a[2] | b[2], a[3] | b[3]]
}

#[inline(always)]
fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    [a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2], a[3] ^ b[3]]
}

#[inline(always)]
fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    [
        a[0].wrapping_add(b[0]),
        a[1].wrapping_add(b[1]),
        a[2].wrapping_add(b[2]),
        a[3].wrapping_add(b[3]),
    ]
}

fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    [v3[3], v2[0], v2[1], v2[2]]
}
fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] {
    macro_rules! big_sigma0 {
        ($a:expr) => {
            ($a.rotate_right(2) ^ $a.rotate_right(13) ^ $a.rotate_right(22))
        };
    }
    macro_rules! big_sigma1 {
        ($a:expr) => {
            ($a.rotate_right(6) ^ $a.rotate_right(11) ^ $a.rotate_right(25))
        };
    }
    macro_rules! bool3ary_202 {
        ($a:expr, $b:expr, $c:expr) => {
            $c ^ ($a & ($b ^ $c))
        };
    } // Choose, MD5F, SHA1C
    macro_rules! bool3ary_232 {
        ($a:expr, $b:expr, $c:expr) => {
            ($a & $b) ^ ($a & $c) ^ ($b & $c)
        };
    } // Majority, SHA1M

    let [_, _, wk1, wk0] = wk;
    let [a0, b0, e0, f0] = abef;
    let [c0, d0, g0, h0] = cdgh;

    // a round
    let x0 =
        big_sigma1!(e0).wrapping_add(bool3ary_202!(e0, f0, g0)).wrapping_add(wk0).wrapping_add(h0);
    let y0 = big_sigma0!(a0).wrapping_add(bool3ary_232!(a0, b0, c0));
    let (a1, b1, c1, d1, e1, f1, g1, h1) =
        (x0.wrapping_add(y0), a0, b0, c0, x0.wrapping_add(d0), e0, f0, g0);

    // a round
    let x1 =
        big_sigma1!(e1).wrapping_add(bool3ary_202!(e1, f1, g1)).wrapping_add(wk1).wrapping_add(h1);
    let y1 = big_sigma0!(a1).wrapping_add(bool3ary_232!(a1, b1, c1));
    let (a2, b2, _, _, e2, f2, _, _) =
        (x1.wrapping_add(y1), a1, b1, c1, x1.wrapping_add(d1), e1, f1, g1);

    [a2, b2, e2, f2]
}
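
// For reference, each "a round" step above is the standard FIPS 180-4
// compression round, with the round constant already folded into `wk` by the
// caller (the intrinsic receives W_t + K_t):
//
//     T1 = h + big_sigma1(e) + Ch(e, f, g) + (W_t + K_t)
//     T2 = big_sigma0(a) + Maj(a, b, c)
//     (a, b, c, d, e, f, g, h) <- (T1 + T2, a, b, c, d + T1, e, f, g)
//
// so `x0`/`x1` compute T1 and `y0`/`y1` compute T2.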
fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] {
    // sigma 0 on vectors
    #[inline]
    fn sigma0x4(x: [u32; 4]) -> [u32; 4] {
        let t1 = or(shr(x, 7), shl(x, 25));
        let t2 = or(shr(x, 18), shl(x, 14));
        let t3 = shr(x, 3);
        xor(xor(t1, t2), t3)
    }
    add(v0, sigma0x4(sha256load(v0, v1)))
}

fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    macro_rules! sigma1 {
        ($a:expr) => {
            $a.rotate_right(17) ^ $a.rotate_right(19) ^ ($a >> 10)
        };
    }

    let [x3, x2, x1, x0] = v4;
    let [w15, w14, _, _] = v3;

    let w16 = x0.wrapping_add(sigma1!(w14));
    let w17 = x1.wrapping_add(sigma1!(w15));
    let w18 = x2.wrapping_add(sigma1!(w16));
    let w19 = x3.wrapping_add(sigma1!(w17));

    [w19, w18, w17, w16]
}
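
// Putting the two halves together (a scalar reference sketch, not part of
// this module): the full SHA256 message schedule is
//
//     W[t] = sigma1(W[t-2]) + W[t-7] + sigma0(W[t-15]) + W[t-16]
//
// `sha256msg1` contributes the sigma0 term, `sha256msg2` the sigma1 term, and
// the W[t-7] addend is supplied by the caller in between (see `schedule` in
// the test file). As plain code:
//
//     fn next_w(w: &[u32]) -> u32 {
//         let t = w.len();
//         let sigma0 = |x: u32| x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3);
//         let sigma1 = |x: u32| x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10);
//         sigma1(w[t - 2])
//             .wrapping_add(w[t - 7])
//             .wrapping_add(sigma0(w[t - 15]))
//             .wrapping_add(w[t - 16])
//     }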

@@ -0,0 +1,270 @@
// Ignore everything except x86 and x86_64
// Any new targets that are added to CI should be ignored here.
// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
//@ignore-target-aarch64
//@ignore-target-arm
//@ignore-target-avr
//@ignore-target-s390x
//@ignore-target-thumbv7em
//@ignore-target-wasm32
//@compile-flags: -C target-feature=+sha,+sse2,+ssse3,+sse4.1
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
macro_rules! rounds4 {
    ($abef:ident, $cdgh:ident, $rest:expr, $i:expr) => {{
        let k = K32X4[$i];
        let kv = _mm_set_epi32(k[0] as i32, k[1] as i32, k[2] as i32, k[3] as i32);
        let t1 = _mm_add_epi32($rest, kv);
        $cdgh = _mm_sha256rnds2_epu32($cdgh, $abef, t1);
        let t2 = _mm_shuffle_epi32(t1, 0x0E);
        $abef = _mm_sha256rnds2_epu32($abef, $cdgh, t2);
    }};
}

macro_rules! schedule_rounds4 {
    (
        $abef:ident, $cdgh:ident,
        $w0:expr, $w1:expr, $w2:expr, $w3:expr, $w4:expr,
        $i:expr
    ) => {{
        $w4 = schedule($w0, $w1, $w2, $w3);
        rounds4!($abef, $cdgh, $w4, $i);
    }};
}
fn main() {
    assert!(is_x86_feature_detected!("sha"));
    assert!(is_x86_feature_detected!("sse2"));
    assert!(is_x86_feature_detected!("ssse3"));
    assert!(is_x86_feature_detected!("sse4.1"));

    unsafe {
        test_sha256rnds2();
        test_sha256msg1();
        test_sha256msg2();
        test_sha256();
    }
}
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn test_sha256rnds2() {
    let test_vectors = [
        (
            [0x3c6ef372, 0xa54ff53a, 0x1f83d9ab, 0x5be0cd19],
            [0x6a09e667, 0xbb67ae85, 0x510e527f, 0x9b05688c],
            [0x592340c6, 0x17386142, 0x91a0b7b1, 0x94ffa30c],
            [0xeef39c6c, 0x4e7dfbc1, 0x467a98f3, 0xeb3d5616],
        ),
        (
            [0x6a09e667, 0xbb67ae85, 0x510e527f, 0x9b05688c],
            [0xeef39c6c, 0x4e7dfbc1, 0x467a98f3, 0xeb3d5616],
            [0x91a0b7b1, 0x94ffa30c, 0x592340c6, 0x17386142],
            [0x7e7f3c9d, 0x78db9a20, 0xd82fe6ed, 0xaf1f2704],
        ),
        (
            [0xeef39c6c, 0x4e7dfbc1, 0x467a98f3, 0xeb3d5616],
            [0x7e7f3c9d, 0x78db9a20, 0xd82fe6ed, 0xaf1f2704],
            [0x1a89c3f6, 0xf3b6e817, 0x7a5a8511, 0x8bcc35cf],
            [0xc9292f7e, 0x49137bd9, 0x7e5f9e08, 0xd10f9247],
        ),
    ];

    for (cdgh, abef, wk, expected) in test_vectors {
        let output_reg = _mm_sha256rnds2_epu32(set_arr(cdgh), set_arr(abef), set_arr(wk));
        let mut output = [0u32; 4];
        _mm_storeu_si128(output.as_mut_ptr().cast(), output_reg);
        // The values are stored as little endian, so we need to reverse them
        output.reverse();
        assert_eq!(output, expected);
    }
}
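
// Note (an observation about the vectors above, not from the commit): the
// three tuples chain together - each call's `abef` input reappears as the
// next call's `cdgh`, and each expected output becomes the next `abef` - so
// they trace six consecutive rounds of a real compression run.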
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn test_sha256msg1() {
    let test_vectors = [
        (
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x2da4b536, 0x77f29328, 0x541a4d59, 0x6afb680c],
        ),
        (
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x2da4b536, 0x77f29328, 0x541a4d59, 0x6afb680c],
        ),
        (
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0x2da4b536, 0x77f29328, 0x541a4d59, 0x6afb680c],
        ),
    ];

    for (v0, v1, expected) in test_vectors {
        let output_reg = _mm_sha256msg1_epu32(set_arr(v0), set_arr(v1));
        let mut output = [0u32; 4];
        _mm_storeu_si128(output.as_mut_ptr().cast(), output_reg);
        // The values are stored as little endian, so we need to reverse them
        output.reverse();
        assert_eq!(output, expected);
    }
}
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn test_sha256msg2() {
    let test_vectors = [
        (
            [0x801a28aa, 0xe75ff849, 0xb591b2cc, 0x8b64db2c],
            [0x6f6d6521, 0x61776573, 0x20697320, 0x52757374],
            [0xe7c46c4e, 0x8ce92ccc, 0xd3c0f3ce, 0xe9745c78],
        ),
        (
            [0x171911ae, 0xe75ff849, 0xb591b2cc, 0x8b64db2c],
            [0xe7c46c4e, 0x8ce92ccc, 0xd3c0f3ce, 0xe9745c78],
            [0xc17c6ea3, 0xc4d10083, 0x712910cd, 0x3f41c8ce],
        ),
        (
            [0x6ce67e04, 0x5fb6ff76, 0xe1037a25, 0x3ebc5bda],
            [0xc17c6ea3, 0xc4d10083, 0x712910cd, 0x3f41c8ce],
            [0xf5ab4eff, 0x83d732a5, 0x9bb941af, 0xdf1d0a8c],
        ),
    ];

    for (v4, v3, expected) in test_vectors {
        let output_reg = _mm_sha256msg2_epu32(set_arr(v4), set_arr(v3));
        let mut output = [0u32; 4];
        _mm_storeu_si128(output.as_mut_ptr().cast(), output_reg);
        // The values are stored as little endian, so we need to reverse them
        output.reverse();
        assert_eq!(output, expected);
    }
}
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn set_arr(x: [u32; 4]) -> __m128i {
    _mm_set_epi32(x[0] as i32, x[1] as i32, x[2] as i32, x[3] as i32)
}
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn test_sha256() {
    use std::fmt::Write;

    /// The initial state of the hash engine.
    const INITIAL_STATE: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // We don't want to bother with the hash finalization algorithm, so we just feed constant
    // data. This is the content that's being hashed - you can feed it to sha256sum and it'll
    // output the same hash (beware of newlines though).
    let first_block = *b"Rust is awesome!Rust is awesome!Rust is awesome!Rust is awesome!";
    // SHA256 is finalized by appending 0x80, then zeros, and finally the data length in bits
    // (big endian) at the end. Here the message is exactly one 64-byte block, i.e.
    // 8 * 64 = 512 bits.
    let mut final_block = [0; 64];
    final_block[0] = 0x80;
    final_block[(64 - 8)..].copy_from_slice(&(8u64 * 64).to_be_bytes());

    let mut state = INITIAL_STATE;
    digest_blocks(&mut state, &[first_block, final_block]);

    // We compare strings because it's easier to check the hex and the output of panic.
    let mut hash = String::new();
    for chunk in &state {
        write!(hash, "{:08x}", chunk).expect("writing to String doesn't fail");
    }

    assert_eq!(hash, "1b2293d21b17a0cb0c18737307c37333dea775eded18cefed45e50389f9f8184");
}
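
// As the comment in `test_sha256` says, the digest can be cross-checked
// outside Miri; for example, assuming a POSIX shell (where `printf` adds no
// trailing newline):
//
//     printf 'Rust is awesome!Rust is awesome!Rust is awesome!Rust is awesome!' | sha256sum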
// Almost full SHA256 implementation copied from RustCrypto's sha2 crate:
// https://github.com/RustCrypto/hashes/blob/6be8466247e936c415d8aafb848697f39894a386/sha2/src/sha256/x86.rs

#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn schedule(v0: __m128i, v1: __m128i, v2: __m128i, v3: __m128i) -> __m128i {
    let t1 = _mm_sha256msg1_epu32(v0, v1);
    let t2 = _mm_alignr_epi8(v3, v2, 4);
    let t3 = _mm_add_epi32(t1, t2);
    _mm_sha256msg2_epu32(t3, v3)
}
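
// A note on `schedule` (an observation, not from the original source): with
// v0..v3 holding W[t-16..t-12], W[t-12..t-8], W[t-8..t-4] and W[t-4..t],
// `_mm_alignr_epi8(v3, v2, 4)` shifts the concatenation v3:v2 right by four
// bytes, producing exactly the four W[t-7] addends of the schedule
// recurrence; `_mm_add_epi32` folds them in between the msg1 and msg2 steps.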
// we use unaligned loads with `__m128i` pointers
#[allow(clippy::cast_ptr_alignment)]
#[target_feature(enable = "sha,sse2,ssse3,sse4.1")]
unsafe fn digest_blocks(state: &mut [u32; 8], blocks: &[[u8; 64]]) {
    #[allow(non_snake_case)]
    let MASK: __m128i =
        _mm_set_epi64x(0x0C0D_0E0F_0809_0A0Bu64 as i64, 0x0405_0607_0001_0203u64 as i64);

    let state_ptr: *const __m128i = state.as_ptr().cast();
    let dcba = _mm_loadu_si128(state_ptr.add(0));
    let efgh = _mm_loadu_si128(state_ptr.add(1));

    let cdab = _mm_shuffle_epi32(dcba, 0xB1);
    let efgh = _mm_shuffle_epi32(efgh, 0x1B);
    let mut abef = _mm_alignr_epi8(cdab, efgh, 8);
    let mut cdgh = _mm_blend_epi16(efgh, cdab, 0xF0);

    for block in blocks {
        let abef_save = abef;
        let cdgh_save = cdgh;

        let block_ptr: *const __m128i = block.as_ptr().cast();
        let mut w0 = _mm_shuffle_epi8(_mm_loadu_si128(block_ptr.add(0)), MASK);
        let mut w1 = _mm_shuffle_epi8(_mm_loadu_si128(block_ptr.add(1)), MASK);
        let mut w2 = _mm_shuffle_epi8(_mm_loadu_si128(block_ptr.add(2)), MASK);
        let mut w3 = _mm_shuffle_epi8(_mm_loadu_si128(block_ptr.add(3)), MASK);
        let mut w4;

        rounds4!(abef, cdgh, w0, 0);
        rounds4!(abef, cdgh, w1, 1);
        rounds4!(abef, cdgh, w2, 2);
        rounds4!(abef, cdgh, w3, 3);
        schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 4);
        schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 5);
        schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 6);
        schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 7);
        schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 8);
        schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 9);
        schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 10);
        schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 11);
        schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 12);
        schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 13);
        schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 14);
        schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 15);

        abef = _mm_add_epi32(abef, abef_save);
        cdgh = _mm_add_epi32(cdgh, cdgh_save);
    }

    let feba = _mm_shuffle_epi32(abef, 0x1B);
    let dchg = _mm_shuffle_epi32(cdgh, 0xB1);
    let dcba = _mm_blend_epi16(feba, dchg, 0xF0);
    let hgef = _mm_alignr_epi8(dchg, feba, 8);

    let state_ptr_mut: *mut __m128i = state.as_mut_ptr().cast();
    _mm_storeu_si128(state_ptr_mut.add(0), dcba);
    _mm_storeu_si128(state_ptr_mut.add(1), hgef);
}
/// Swapped round constants for SHA-256 family of digests
pub static K32X4: [[u32; 4]; 16] = {
    let mut res = [[0u32; 4]; 16];
    let mut i = 0;
    while i < 16 {
        res[i] = [K32[4 * i + 3], K32[4 * i + 2], K32[4 * i + 1], K32[4 * i]];
        i += 1;
    }
    res
};
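
// Why "swapped" (an observation): `rounds4!` builds `kv` with
// `_mm_set_epi32(k[0], k[1], k[2], k[3])`, which places its first argument in
// the highest lane. Reversing each group of four constants here means lane j
// of `kv` ends up holding K32[4 * i + j], matching the lane order of the
// message words after the byte shuffle with MASK.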
/// Round constants for SHA-256 family of digests
pub static K32: [u32; 64] = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];