Rollup merge of #136690 - Voultapher:use-more-explicit-and-reliable-ptr-select, r=thomcc

Use more explicit and reliable ptr select in sort impls

Using `if ...` with the intent to avoid branches can be surprising to readers and carries the risk of turning into jumps/branches generated by some future compiler version, breaking crucial optimizations.

This commit replaces their usage with the explicit and IR annotated `bool::select_unpredictable`.
This commit is contained in:
Matthias Krüger 2025-02-19 21:16:09 +01:00 committed by GitHub
commit 59d2b102b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -387,7 +387,7 @@ unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
// SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
unsafe {
let v_a = v_base.add(a_pos);
@ -404,16 +404,16 @@ where
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap(left, right, 1);
// ptr::swap(v_a, v_b, 1);
// }
// The goal is to generate cmov instructions here.
let left_swap = if should_swap { v_b } else { v_a };
let right_swap = if should_swap { v_a } else { v_b };
let v_a_swap = should_swap.select_unpredictable(v_b, v_a);
let v_b_swap = should_swap.select_unpredictable(v_a, v_b);
let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
ptr::copy(left_swap, v_a, 1);
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap));
ptr::copy(v_a_swap, v_a, 1);
ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1);
}
}
@ -640,26 +640,21 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
// 1, 1 | c b a d
let c3 = is_less(&*c, &*a);
let c4 = is_less(&*d, &*b);
let min = select(c3, c, a);
let max = select(c4, b, d);
let unknown_left = select(c3, a, select(c4, c, b));
let unknown_right = select(c4, d, select(c3, b, c));
let min = c3.select_unpredictable(c, a);
let max = c4.select_unpredictable(b, d);
let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b));
let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c));
// Sort the last two unknown elements.
let c5 = is_less(&*unknown_right, &*unknown_left);
let lo = select(c5, unknown_right, unknown_left);
let hi = select(c5, unknown_left, unknown_right);
let lo = c5.select_unpredictable(unknown_right, unknown_left);
let hi = c5.select_unpredictable(unknown_left, unknown_right);
ptr::copy_nonoverlapping(min, dst, 1);
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
ptr::copy_nonoverlapping(max, dst.add(3), 1);
}
#[inline(always)]
fn select<T>(cond: bool, if_true: *const T, if_false: *const T) -> *const T {
if cond { if_true } else { if_false }
}
}
/// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and