optimize initialization checks

This commit is contained in:
Erik Desjardins 2021-08-10 19:29:18 -04:00
parent c9599c4cac
commit 1eaccab24e

View File

@ -1,7 +1,7 @@
//! The virtual memory representation of the MIR interpreter.
use std::borrow::Cow;
use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::iter;
use std::ops::{Deref, Range};
use std::ptr;
@ -720,13 +720,12 @@ impl InitMask {
return Err(self.len..end);
}
// FIXME(oli-obk): optimize this for allocations larger than a block.
let idx = (start..end).find(|&i| !self.get(i));
let uninit_start = find_bit(self, start, end, false);
match idx {
Some(idx) => {
let uninit_end = (idx..end).find(|&i| self.get(i)).unwrap_or(end);
Err(idx..uninit_end)
match uninit_start {
Some(uninit_start) => {
let uninit_end = find_bit(self, uninit_start, end, true).unwrap_or(end);
Err(uninit_start..uninit_end)
}
None => Ok(()),
}
@ -863,9 +862,8 @@ impl<'a> Iterator for InitChunkIter<'a> {
}
let is_init = self.init_mask.get(self.start);
// FIXME(oli-obk): optimize this for allocations larger than a block.
let end_of_chunk =
(self.start..self.end).find(|&i| self.init_mask.get(i) != is_init).unwrap_or(self.end);
find_bit(&self.init_mask, self.start, self.end, !is_init).unwrap_or(self.end);
let range = self.start..end_of_chunk;
self.start = end_of_chunk;
@ -874,6 +872,94 @@ impl<'a> Iterator for InitChunkIter<'a> {
}
}
/// Returns the index of the first bit in `start..end` (end-exclusive) that is equal to is_init.
fn find_bit(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
    /// Fast path: scans the mask one whole `Block` at a time using bit
    /// operations, instead of testing a single bit per iteration.
    fn find_bit_fast(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
        /// Searches the single block `bits` (at block index `block`), starting
        /// from `start_bit`, for the first bit equal to `is_init`. The returned
        /// index is absolute and may lie past the caller's `end` — every call
        /// site below re-checks `i < end`.
        fn search_block(
            bits: Block,
            block: usize,
            start_bit: usize,
            is_init: bool,
        ) -> Option<Size> {
            // invert bits so we're always looking for the first set bit
            let bits = if is_init { bits } else { !bits };
            // mask off unused start bits
            let bits = bits & (!0 << start_bit);
            // find set bit, if any
            if bits == 0 {
                None
            } else {
                // trailing_zeros gives the lowest set bit, i.e. the first
                // matching position at or after `start_bit` in this block.
                let bit = bits.trailing_zeros();
                Some(size_from_bit_index(block, bit))
            }
        }

        // Empty range: nothing to find.
        if start >= end {
            return None;
        }

        let (start_block, start_bit) = bit_index(start);
        let (end_block, end_bit) = bit_index(end);

        // handle first block: need to skip `start_bit` bits
        if let Some(i) =
            search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
        {
            if i < end {
                return Some(i);
            } else {
                // if the range is less than a block, we may find a matching bit after `end`
                return None;
            }
        }

        let one_block_past_the_end = if end_bit > 0 {
            // if `end_bit` > 0, then the range overlaps `end_block`
            end_block + 1
        } else {
            end_block
        };

        // handle remaining blocks
        if start_block < one_block_past_the_end {
            // Pair each remaining block's bits with its block index, starting
            // right after the first (already-searched) block.
            for (&bits, block) in init_mask.blocks[start_block + 1..one_block_past_the_end]
                .iter()
                .zip(start_block + 1..)
            {
                // These blocks start fresh, so search from bit 0.
                if let Some(i) = search_block(bits, block, 0, is_init) {
                    if i < end {
                        return Some(i);
                    } else {
                        // if this is the last block, we may find a matching bit after `end`
                        return None;
                    }
                }
            }
        }

        None
    }

    /// Reference implementation: a naive linear scan, one bit per iteration.
    /// Only exercised by the `debug_assert_eq!` below, hence dead code in
    /// release builds.
    #[cfg_attr(not(debug_assertions), allow(dead_code))]
    fn find_bit_slow(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
        (start..end).find(|&i| init_mask.get(i) == is_init)
    }

    let result = find_bit_fast(init_mask, start, end, is_init);

    // In debug builds, cross-check the optimized search against the naive one.
    debug_assert_eq!(
        result,
        find_bit_slow(init_mask, start, end, is_init),
        "optimized implementation of find_bit is wrong for start={:?} end={:?} is_init={} init_mask={:#?}",
        start,
        end,
        is_init,
        init_mask
    );

    result
}
#[inline]
fn bit_index(bits: Size) -> (usize, usize) {
let bits = bits.bytes();
@ -881,3 +967,10 @@ fn bit_index(bits: Size) -> (usize, usize) {
let b = bits % InitMask::BLOCK_SIZE;
(usize::try_from(a).unwrap(), usize::try_from(b).unwrap())
}
/// Builds the `Size` that locates the given `bit` within the given `block`
/// of an `InitMask`.
#[inline]
fn size_from_bit_index(block: impl TryInto<u64>, bit: impl TryInto<u64>) -> Size {
    // `.ok().unwrap()` is used instead of `.unwrap()` because it does not
    // require the conversion error type to implement `Debug`.
    let block_idx: u64 = block.try_into().ok().unwrap();
    let bit_idx: u64 = bit.try_into().ok().unwrap();
    Size::from_bytes(bit_idx + block_idx * InitMask::BLOCK_SIZE)
}