Rollup merge of #81706 - SkiFire13:document-binaryheap-unsafe, r=Mark-Simulacrum
Document BinaryHeap unsafe functions `BinaryHeap` contains some private safe functions but that are actually unsafe to call. This PR marks them `unsafe` and documents all the `unsafe` function calls inside them. While doing this I might also have found a bug: some "SAFETY" comments in `sift_down_range` and `sift_down_to_bottom` are valid only if you assume that `child` doesn't overflow. However it may overflow if `end > isize::MAX` which can be true for ZSTs (but I think only for them). I guess the easiest fix would be to skip any sifting if `mem::size_of::<T> == 0`. Probably conflicts with #81127 but solving the eventual merge conflict should be pretty easy.
This commit is contained in:
commit
56ae3fb2f0
@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
|
|||||||
impl<T: Ord> Drop for PeekMut<'_, T> {
|
impl<T: Ord> Drop for PeekMut<'_, T> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if self.sift {
|
if self.sift {
|
||||||
self.heap.sift_down(0);
|
// SAFETY: PeekMut is only instantiated for non-empty heaps.
|
||||||
|
unsafe { self.heap.sift_down(0) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
|
|||||||
self.data.pop().map(|mut item| {
|
self.data.pop().map(|mut item| {
|
||||||
if !self.is_empty() {
|
if !self.is_empty() {
|
||||||
swap(&mut item, &mut self.data[0]);
|
swap(&mut item, &mut self.data[0]);
|
||||||
self.sift_down_to_bottom(0);
|
// SAFETY: !self.is_empty() means that self.len() > 0
|
||||||
|
unsafe { self.sift_down_to_bottom(0) };
|
||||||
}
|
}
|
||||||
item
|
item
|
||||||
})
|
})
|
||||||
@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
|
|||||||
pub fn push(&mut self, item: T) {
|
pub fn push(&mut self, item: T) {
|
||||||
let old_len = self.len();
|
let old_len = self.len();
|
||||||
self.data.push(item);
|
self.data.push(item);
|
||||||
self.sift_up(0, old_len);
|
// SAFETY: Since we pushed a new item it means that
|
||||||
|
// old_len = self.len() - 1 < self.len()
|
||||||
|
unsafe { self.sift_up(0, old_len) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the `BinaryHeap` and returns a vector in sorted
|
/// Consumes the `BinaryHeap` and returns a vector in sorted
|
||||||
@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
|
|||||||
let ptr = self.data.as_mut_ptr();
|
let ptr = self.data.as_mut_ptr();
|
||||||
ptr::swap(ptr, ptr.add(end));
|
ptr::swap(ptr, ptr.add(end));
|
||||||
}
|
}
|
||||||
self.sift_down_range(0, end);
|
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
|
||||||
|
// 0 < 1 <= end <= self.len() - 1 < self.len()
|
||||||
|
// Which means 0 < end and end < self.len().
|
||||||
|
unsafe { self.sift_down_range(0, end) };
|
||||||
}
|
}
|
||||||
self.into_vec()
|
self.into_vec()
|
||||||
}
|
}
|
||||||
@ -519,47 +526,84 @@ impl<T: Ord> BinaryHeap<T> {
|
|||||||
// the hole is filled back at the end of its scope, even on panic.
|
// the hole is filled back at the end of its scope, even on panic.
|
||||||
// Using a hole reduces the constant factor compared to using swaps,
|
// Using a hole reduces the constant factor compared to using swaps,
|
||||||
// which involves twice as many moves.
|
// which involves twice as many moves.
|
||||||
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
|
|
||||||
unsafe {
|
|
||||||
// Take out the value at `pos` and create a hole.
|
|
||||||
let mut hole = Hole::new(&mut self.data, pos);
|
|
||||||
|
|
||||||
while hole.pos() > start {
|
/// # Safety
|
||||||
let parent = (hole.pos() - 1) / 2;
|
///
|
||||||
if hole.element() <= hole.get(parent) {
|
/// The caller must guarantee that `pos < self.len()`.
|
||||||
break;
|
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
|
||||||
}
|
// Take out the value at `pos` and create a hole.
|
||||||
hole.move_to(parent);
|
// SAFETY: The caller guarantees that pos < self.len()
|
||||||
|
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||||
|
|
||||||
|
while hole.pos() > start {
|
||||||
|
let parent = (hole.pos() - 1) / 2;
|
||||||
|
|
||||||
|
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
|
||||||
|
// and so hole.pos() - 1 can't underflow.
|
||||||
|
// This guarantees that parent < hole.pos() so
|
||||||
|
// it's a valid index and also != hole.pos().
|
||||||
|
if hole.element() <= unsafe { hole.get(parent) } {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
hole.pos()
|
|
||||||
|
// SAFETY: Same as above
|
||||||
|
unsafe { hole.move_to(parent) };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hole.pos()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take an element at `pos` and move it down the heap,
|
/// Take an element at `pos` and move it down the heap,
|
||||||
/// while its children are larger.
|
/// while its children are larger.
|
||||||
fn sift_down_range(&mut self, pos: usize, end: usize) {
|
///
|
||||||
unsafe {
|
/// # Safety
|
||||||
let mut hole = Hole::new(&mut self.data, pos);
|
///
|
||||||
let mut child = 2 * pos + 1;
|
/// The caller must guarantee that `pos < end <= self.len()`.
|
||||||
while child < end - 1 {
|
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
|
||||||
// compare with the greater of the two children
|
// SAFETY: The caller guarantees that pos < end <= self.len().
|
||||||
child += (hole.get(child) <= hole.get(child + 1)) as usize;
|
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||||
// if we are already in order, stop.
|
let mut child = 2 * hole.pos() + 1;
|
||||||
if hole.element() >= hole.get(child) {
|
|
||||||
return;
|
// Loop invariant: child == 2 * hole.pos() + 1.
|
||||||
}
|
while child < end - 1 {
|
||||||
hole.move_to(child);
|
// compare with the greater of the two children
|
||||||
child = 2 * hole.pos() + 1;
|
// SAFETY: child < end - 1 < self.len() and
|
||||||
}
|
// child + 1 < end <= self.len(), so they're valid indexes.
|
||||||
if child == end - 1 && hole.element() < hole.get(child) {
|
// child == 2 * hole.pos() + 1 != hole.pos() and
|
||||||
hole.move_to(child);
|
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
|
||||||
|
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
|
||||||
|
// if T is a ZST
|
||||||
|
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
|
||||||
|
|
||||||
|
// if we are already in order, stop.
|
||||||
|
// SAFETY: child is now either the old child or the old child+1
|
||||||
|
// We already proven that both are < self.len() and != hole.pos()
|
||||||
|
if hole.element() >= unsafe { hole.get(child) } {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SAFETY: same as above.
|
||||||
|
unsafe { hole.move_to(child) };
|
||||||
|
child = 2 * hole.pos() + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: && short circuit, which means that in the
|
||||||
|
// second condition it's already true that child == end - 1 < self.len().
|
||||||
|
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
|
||||||
|
// SAFETY: child is already proven to be a valid index and
|
||||||
|
// child == 2 * hole.pos() + 1 != hole.pos().
|
||||||
|
unsafe { hole.move_to(child) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sift_down(&mut self, pos: usize) {
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// The caller must guarantee that `pos < self.len()`.
|
||||||
|
unsafe fn sift_down(&mut self, pos: usize) {
|
||||||
let len = self.len();
|
let len = self.len();
|
||||||
self.sift_down_range(pos, len);
|
// SAFETY: pos < len is guaranteed by the caller and
|
||||||
|
// obviously len = self.len() <= self.len().
|
||||||
|
unsafe { self.sift_down_range(pos, len) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take an element at `pos` and move it all the way down the heap,
|
/// Take an element at `pos` and move it all the way down the heap,
|
||||||
@ -567,30 +611,54 @@ impl<T: Ord> BinaryHeap<T> {
|
|||||||
///
|
///
|
||||||
/// Note: This is faster when the element is known to be large / should
|
/// Note: This is faster when the element is known to be large / should
|
||||||
/// be closer to the bottom.
|
/// be closer to the bottom.
|
||||||
fn sift_down_to_bottom(&mut self, mut pos: usize) {
|
///
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// The caller must guarantee that `pos < self.len()`.
|
||||||
|
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
|
||||||
let end = self.len();
|
let end = self.len();
|
||||||
let start = pos;
|
let start = pos;
|
||||||
unsafe {
|
|
||||||
let mut hole = Hole::new(&mut self.data, pos);
|
// SAFETY: The caller guarantees that pos < self.len().
|
||||||
let mut child = 2 * pos + 1;
|
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||||
while child < end - 1 {
|
let mut child = 2 * hole.pos() + 1;
|
||||||
child += (hole.get(child) <= hole.get(child + 1)) as usize;
|
|
||||||
hole.move_to(child);
|
// Loop invariant: child == 2 * hole.pos() + 1.
|
||||||
child = 2 * hole.pos() + 1;
|
while child < end - 1 {
|
||||||
}
|
// SAFETY: child < end - 1 < self.len() and
|
||||||
if child == end - 1 {
|
// child + 1 < end <= self.len(), so they're valid indexes.
|
||||||
hole.move_to(child);
|
// child == 2 * hole.pos() + 1 != hole.pos() and
|
||||||
}
|
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
|
||||||
pos = hole.pos;
|
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
|
||||||
|
// if T is a ZST
|
||||||
|
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
|
||||||
|
|
||||||
|
// SAFETY: Same as above
|
||||||
|
unsafe { hole.move_to(child) };
|
||||||
|
child = 2 * hole.pos() + 1;
|
||||||
}
|
}
|
||||||
self.sift_up(start, pos);
|
|
||||||
|
if child == end - 1 {
|
||||||
|
// SAFETY: child == end - 1 < self.len(), so it's a valid index
|
||||||
|
// and child == 2 * hole.pos() + 1 != hole.pos().
|
||||||
|
unsafe { hole.move_to(child) };
|
||||||
|
}
|
||||||
|
pos = hole.pos();
|
||||||
|
drop(hole);
|
||||||
|
|
||||||
|
// SAFETY: pos is the position in the hole and was already proven
|
||||||
|
// to be a valid index.
|
||||||
|
unsafe { self.sift_up(start, pos) };
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rebuild(&mut self) {
|
fn rebuild(&mut self) {
|
||||||
let mut n = self.len() / 2;
|
let mut n = self.len() / 2;
|
||||||
while n > 0 {
|
while n > 0 {
|
||||||
n -= 1;
|
n -= 1;
|
||||||
self.sift_down(n);
|
// SAFETY: n starts from self.len() / 2 and goes down to 0.
|
||||||
|
// The only case when !(n < self.len()) is if
|
||||||
|
// self.len() == 0, but it's ruled out by the loop condition.
|
||||||
|
unsafe { self.sift_down(n) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user