From 4541aa971f94f5b262ac44bbba3ad0561c76477b Mon Sep 17 00:00:00 2001 From: Alexis Bourget Date: Wed, 14 Jul 2021 15:31:12 +0200 Subject: [PATCH] Add safety comments in private core::slice::rotate::ptr_rotate function --- library/core/src/slice/rotate.rs | 56 ++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/library/core/src/slice/rotate.rs b/library/core/src/slice/rotate.rs index a89596b15ef..7528927ef33 100644 --- a/library/core/src/slice/rotate.rs +++ b/library/core/src/slice/rotate.rs @@ -1,5 +1,3 @@ -// ignore-tidy-undocumented-unsafe - use crate::cmp; use crate::mem::{self, MaybeUninit}; use crate::ptr; @@ -79,8 +77,10 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) // the way until about `left + right == 32`, but the worst case performance breaks even // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4 // `usize`s, this algorithm also outperforms other algorithms. + // SAFETY: callers must ensure `mid - left` is valid for reading and writing. let x = unsafe { mid.sub(left) }; // beginning of first round + // SAFETY: see previous comment. let mut tmp: T = unsafe { x.read() }; let mut i = right; // `gcd` can be found before hand by calculating `gcd(left + right, right)`, @@ -92,6 +92,21 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) // the very end. This is possibly due to the fact that swapping or replacing temporaries // uses only one memory address in the loop instead of needing to manage two. loop { + // [long-safety-expl] + // SAFETY: callers must ensure `[left, left+mid+right)` are all valid for reading and + // writing. + // + // - `i` start with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right` + // - `i <= left+right-1` is always true + // - if `i < left`, `right` is added so `i < left+right` and on the next + // iteration `left` is removed from `i` so it doesn't go further + // - if `i >= left`, `left` is removed immediately and so it doesn't go further. + // - overflows cannot happen for `i` since the function's safety contract ask for + // `mid+right-1 = x+left+right` to be valid for writing + // - underflows cannot happen because `i` must be bigger or equal to `left` for + // a substraction of `left` to happen. + // + // So `x+i` is valid for reading and writing if the caller respected the contract tmp = unsafe { x.add(i).replace(tmp) }; // instead of incrementing `i` and then checking if it is outside the bounds, we // check if `i` will go outside the bounds on the next increment. This prevents @@ -100,6 +115,8 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) i -= left; if i == 0 { // end of first round + // SAFETY: tmp has been read from a valid source and x is valid for writing + // according to the caller. unsafe { x.write(tmp) }; break; } @@ -113,13 +130,24 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) } // finish the chunk with more rounds for start in 1..gcd { + // SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for + // reading and writing as per the function's safety contract, see [long-safety-expl] + // above tmp = unsafe { x.add(start).read() }; + // [safety-expl-addition] + // + // Here `start < gcd` so `start < right` so `i < right+right`: `right` being the + // greatest common divisor of `(left+right, right)` means that `left = right` so + // `i < left+right` so `x+i = mid-left+i` is always valid for reading and writing + // according to the function's safety contract. i = start + right; loop { + // SAFETY: see [long-safety-expl] and [safety-expl-addition] tmp = unsafe { x.add(i).replace(tmp) }; if i >= left { i -= left; if i == start { + // SAFETY: see [long-safety-expl] and [safety-expl-addition] unsafe { x.add(start).write(tmp) }; break; } @@ -135,14 +163,30 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) // The `[T; 0]` here is to ensure this is appropriately aligned for T let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit(); let buf = rawarray.as_mut_ptr() as *mut T; + // SAFETY: `mid-left <= mid-left+right < mid+right` let dim = unsafe { mid.sub(left).add(right) }; if left <= right { + // SAFETY: + // + // 1) The `else if` condition about the sizes ensures `[mid-left; left]` will fit in + // `buf` without overflow and `buf` was created just above and so cannot be + // overlapped with any value of `[mid-left; left]` + // 2) [mid-left, mid+right) are all valid for reading and writing and we don't care + // about overlaps here. + // 3) The `if` condition about `left <= right` ensures writing `left` elements to + // `dim = mid-left+right` is valid because: + // - `buf` is valid and `left` elements were written in it in 1) + // - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)` unsafe { + // 1) ptr::copy_nonoverlapping(mid.sub(left), buf, left); + // 2) ptr::copy(mid, mid.sub(left), right); + // 3) ptr::copy_nonoverlapping(buf, dim, left); } } else { + // SAFETY: same reasoning as above but with `left` and `right` reversed unsafe { ptr::copy_nonoverlapping(mid, buf, right); ptr::copy(mid.sub(left), dim, left); @@ -156,6 +200,10 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) // of this algorithm would be, and swapping using that last chunk instead of swapping // adjacent chunks like this algorithm is doing, but this way is still faster. loop { + // SAFETY: + // `left >= right` so `[mid-right, mid+right)` is valid for reading and writing + // Substracting `right` from `mid` each turn is counterbalanced by the addition and + // check after it. unsafe { ptr::swap_nonoverlapping(mid.sub(right), mid, right); mid = mid.sub(right); @@ -168,6 +216,10 @@ pub unsafe fn ptr_rotate(mut left: usize, mut mid: *mut T, mut right: usize) } else { // Algorithm 3, `left < right` loop { + // SAFETY: `[mid-left, mid+left)` is valid for reading and writing because + // `left < right` so `mid+left < mid+right`. + // Adding `left` to `mid` each turn is counterbalanced by the substraction and check + // after it. unsafe { ptr::swap_nonoverlapping(mid.sub(left), mid, left); mid = mid.add(left);