Rollup merge of #81706 - SkiFire13:document-binaryheap-unsafe, r=Mark-Simulacrum
Document BinaryHeap unsafe functions `BinaryHeap` contains some private safe functions but that are actually unsafe to call. This PR marks them `unsafe` and documents all the `unsafe` function calls inside them. While doing this I might also have found a bug: some "SAFETY" comments in `sift_down_range` and `sift_down_to_bottom` are valid only if you assume that `child` doesn't overflow. However it may overflow if `end > isize::MAX` which can be true for ZSTs (but I think only for them). I guess the easiest fix would be to skip any sifting if `mem::size_of::<T> == 0`. Probably conflicts with #81127 but solving the eventual merge conflict should be pretty easy.
This commit is contained in:
commit
56ae3fb2f0
1 changed files with 116 additions and 48 deletions
|
@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
|
|||
impl<T: Ord> Drop for PeekMut<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
if self.sift {
|
||||
self.heap.sift_down(0);
|
||||
// SAFETY: PeekMut is only instantiated for non-empty heaps.
|
||||
unsafe { self.heap.sift_down(0) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
|
|||
self.data.pop().map(|mut item| {
|
||||
if !self.is_empty() {
|
||||
swap(&mut item, &mut self.data[0]);
|
||||
self.sift_down_to_bottom(0);
|
||||
// SAFETY: !self.is_empty() means that self.len() > 0
|
||||
unsafe { self.sift_down_to_bottom(0) };
|
||||
}
|
||||
item
|
||||
})
|
||||
|
@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
|
|||
pub fn push(&mut self, item: T) {
|
||||
let old_len = self.len();
|
||||
self.data.push(item);
|
||||
self.sift_up(0, old_len);
|
||||
// SAFETY: Since we pushed a new item it means that
|
||||
// old_len = self.len() - 1 < self.len()
|
||||
unsafe { self.sift_up(0, old_len) };
|
||||
}
|
||||
|
||||
/// Consumes the `BinaryHeap` and returns a vector in sorted
|
||||
|
@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
|
|||
let ptr = self.data.as_mut_ptr();
|
||||
ptr::swap(ptr, ptr.add(end));
|
||||
}
|
||||
self.sift_down_range(0, end);
|
||||
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
|
||||
// 0 < 1 <= end <= self.len() - 1 < self.len()
|
||||
// Which means 0 < end and end < self.len().
|
||||
unsafe { self.sift_down_range(0, end) };
|
||||
}
|
||||
self.into_vec()
|
||||
}
|
||||
|
@ -519,47 +526,84 @@ impl<T: Ord> BinaryHeap<T> {
|
|||
// the hole is filled back at the end of its scope, even on panic.
|
||||
// Using a hole reduces the constant factor compared to using swaps,
|
||||
// which involves twice as many moves.
|
||||
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
|
||||
unsafe {
|
||||
// Take out the value at `pos` and create a hole.
|
||||
let mut hole = Hole::new(&mut self.data, pos);
|
||||
|
||||
while hole.pos() > start {
|
||||
let parent = (hole.pos() - 1) / 2;
|
||||
if hole.element() <= hole.get(parent) {
|
||||
break;
|
||||
}
|
||||
hole.move_to(parent);
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must guarantee that `pos < self.len()`.
|
||||
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
|
||||
// Take out the value at `pos` and create a hole.
|
||||
// SAFETY: The caller guarantees that pos < self.len()
|
||||
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||
|
||||
while hole.pos() > start {
|
||||
let parent = (hole.pos() - 1) / 2;
|
||||
|
||||
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
|
||||
// and so hole.pos() - 1 can't underflow.
|
||||
// This guarantees that parent < hole.pos() so
|
||||
// it's a valid index and also != hole.pos().
|
||||
if hole.element() <= unsafe { hole.get(parent) } {
|
||||
break;
|
||||
}
|
||||
hole.pos()
|
||||
|
||||
// SAFETY: Same as above
|
||||
unsafe { hole.move_to(parent) };
|
||||
}
|
||||
|
||||
hole.pos()
|
||||
}
|
||||
|
||||
/// Take an element at `pos` and move it down the heap,
|
||||
/// while its children are larger.
|
||||
fn sift_down_range(&mut self, pos: usize, end: usize) {
|
||||
unsafe {
|
||||
let mut hole = Hole::new(&mut self.data, pos);
|
||||
let mut child = 2 * pos + 1;
|
||||
while child < end - 1 {
|
||||
// compare with the greater of the two children
|
||||
child += (hole.get(child) <= hole.get(child + 1)) as usize;
|
||||
// if we are already in order, stop.
|
||||
if hole.element() >= hole.get(child) {
|
||||
return;
|
||||
}
|
||||
hole.move_to(child);
|
||||
child = 2 * hole.pos() + 1;
|
||||
}
|
||||
if child == end - 1 && hole.element() < hole.get(child) {
|
||||
hole.move_to(child);
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must guarantee that `pos < end <= self.len()`.
|
||||
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
|
||||
// SAFETY: The caller guarantees that pos < end <= self.len().
|
||||
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||
let mut child = 2 * hole.pos() + 1;
|
||||
|
||||
// Loop invariant: child == 2 * hole.pos() + 1.
|
||||
while child < end - 1 {
|
||||
// compare with the greater of the two children
|
||||
// SAFETY: child < end - 1 < self.len() and
|
||||
// child + 1 < end <= self.len(), so they're valid indexes.
|
||||
// child == 2 * hole.pos() + 1 != hole.pos() and
|
||||
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
|
||||
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
|
||||
// if T is a ZST
|
||||
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
|
||||
|
||||
// if we are already in order, stop.
|
||||
// SAFETY: child is now either the old child or the old child+1
|
||||
// We already proven that both are < self.len() and != hole.pos()
|
||||
if hole.element() >= unsafe { hole.get(child) } {
|
||||
return;
|
||||
}
|
||||
|
||||
// SAFETY: same as above.
|
||||
unsafe { hole.move_to(child) };
|
||||
child = 2 * hole.pos() + 1;
|
||||
}
|
||||
|
||||
// SAFETY: && short circuit, which means that in the
|
||||
// second condition it's already true that child == end - 1 < self.len().
|
||||
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
|
||||
// SAFETY: child is already proven to be a valid index and
|
||||
// child == 2 * hole.pos() + 1 != hole.pos().
|
||||
unsafe { hole.move_to(child) };
|
||||
}
|
||||
}
|
||||
|
||||
fn sift_down(&mut self, pos: usize) {
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must guarantee that `pos < self.len()`.
|
||||
unsafe fn sift_down(&mut self, pos: usize) {
|
||||
let len = self.len();
|
||||
self.sift_down_range(pos, len);
|
||||
// SAFETY: pos < len is guaranteed by the caller and
|
||||
// obviously len = self.len() <= self.len().
|
||||
unsafe { self.sift_down_range(pos, len) };
|
||||
}
|
||||
|
||||
/// Take an element at `pos` and move it all the way down the heap,
|
||||
|
@ -567,30 +611,54 @@ impl<T: Ord> BinaryHeap<T> {
|
|||
///
|
||||
/// Note: This is faster when the element is known to be large / should
|
||||
/// be closer to the bottom.
|
||||
fn sift_down_to_bottom(&mut self, mut pos: usize) {
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must guarantee that `pos < self.len()`.
|
||||
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
|
||||
let end = self.len();
|
||||
let start = pos;
|
||||
unsafe {
|
||||
let mut hole = Hole::new(&mut self.data, pos);
|
||||
let mut child = 2 * pos + 1;
|
||||
while child < end - 1 {
|
||||
child += (hole.get(child) <= hole.get(child + 1)) as usize;
|
||||
hole.move_to(child);
|
||||
child = 2 * hole.pos() + 1;
|
||||
}
|
||||
if child == end - 1 {
|
||||
hole.move_to(child);
|
||||
}
|
||||
pos = hole.pos;
|
||||
|
||||
// SAFETY: The caller guarantees that pos < self.len().
|
||||
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
|
||||
let mut child = 2 * hole.pos() + 1;
|
||||
|
||||
// Loop invariant: child == 2 * hole.pos() + 1.
|
||||
while child < end - 1 {
|
||||
// SAFETY: child < end - 1 < self.len() and
|
||||
// child + 1 < end <= self.len(), so they're valid indexes.
|
||||
// child == 2 * hole.pos() + 1 != hole.pos() and
|
||||
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
|
||||
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
|
||||
// if T is a ZST
|
||||
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
|
||||
|
||||
// SAFETY: Same as above
|
||||
unsafe { hole.move_to(child) };
|
||||
child = 2 * hole.pos() + 1;
|
||||
}
|
||||
self.sift_up(start, pos);
|
||||
|
||||
if child == end - 1 {
|
||||
// SAFETY: child == end - 1 < self.len(), so it's a valid index
|
||||
// and child == 2 * hole.pos() + 1 != hole.pos().
|
||||
unsafe { hole.move_to(child) };
|
||||
}
|
||||
pos = hole.pos();
|
||||
drop(hole);
|
||||
|
||||
// SAFETY: pos is the position in the hole and was already proven
|
||||
// to be a valid index.
|
||||
unsafe { self.sift_up(start, pos) };
|
||||
}
|
||||
|
||||
fn rebuild(&mut self) {
|
||||
let mut n = self.len() / 2;
|
||||
while n > 0 {
|
||||
n -= 1;
|
||||
self.sift_down(n);
|
||||
// SAFETY: n starts from self.len() / 2 and goes down to 0.
|
||||
// The only case when !(n < self.len()) is if
|
||||
// self.len() == 0, but it's ruled out by the loop condition.
|
||||
unsafe { self.sift_down(n) };
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue