Use IndexVec for coroutine local mapping

This commit is contained in:
Liu Dingming 2024-07-04 04:52:01 +08:00
parent 2db4ff40af
commit a9194f30eb
2 changed files with 21 additions and 19 deletions

View File

@ -208,6 +208,11 @@ pub fn get_or_insert_with(&mut self, index: I, value: impl FnOnce() -> T) -> &mu
pub fn remove(&mut self, index: I) -> Option<T> { pub fn remove(&mut self, index: I) -> Option<T> {
self.get_mut(index)?.take() self.get_mut(index)?.take()
} }
#[inline]
pub fn contains(&self, index: I) -> bool {
self.get(index).and_then(Option::as_ref).is_some()
}
} }
impl<I: Idx, T: fmt::Debug> fmt::Debug for IndexVec<I, T> { impl<I: Idx, T: fmt::Debug> fmt::Debug for IndexVec<I, T> {

View File

@ -58,7 +58,7 @@
use crate::errors; use crate::errors;
use crate::pass_manager as pm; use crate::pass_manager as pm;
use crate::simplify; use crate::simplify;
use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::fx::FxHashSet;
use rustc_errors::pluralize; use rustc_errors::pluralize;
use rustc_hir as hir; use rustc_hir as hir;
use rustc_hir::lang_items::LangItem; use rustc_hir::lang_items::LangItem;
@ -236,8 +236,7 @@ struct TransformVisitor<'tcx> {
discr_ty: Ty<'tcx>, discr_ty: Ty<'tcx>,
// Mapping from Local to (type of local, coroutine struct index) // Mapping from Local to (type of local, coroutine struct index)
// FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`. remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
// A map from a suspension point in a block to the locals which have live storage at that point // A map from a suspension point in a block to the locals which have live storage at that point
storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>, storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
@ -485,7 +484,7 @@ fn tcx(&self) -> TyCtxt<'tcx> {
} }
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
assert_eq!(self.remap.get(local), None); assert!(!self.remap.contains(*local));
} }
fn visit_place( fn visit_place(
@ -495,7 +494,7 @@ fn visit_place(
_location: Location, _location: Location,
) { ) {
// Replace an Local in the remap with a coroutine struct access // Replace an Local in the remap with a coroutine struct access
if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) { if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx); replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
} }
} }
@ -504,7 +503,7 @@ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockDat
// Remove StorageLive and StorageDead statements for remapped locals // Remove StorageLive and StorageDead statements for remapped locals
data.retain_statements(|s| match s.kind { data.retain_statements(|s| match s.kind {
StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => { StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => {
!self.remap.contains_key(&l) !self.remap.contains(l)
} }
_ => true, _ => true,
}); });
@ -529,13 +528,9 @@ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockDat
// The resume arg target location might itself be remapped if its base local is // The resume arg target location might itself be remapped if its base local is
// live across a yield. // live across a yield.
let resume_arg = if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
if let Some(&(ty, variant, idx)) = self.remap.get(&resume_arg.local) {
replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx); replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
resume_arg }
} else {
resume_arg
};
let storage_liveness: GrowableBitSet<Local> = let storage_liveness: GrowableBitSet<Local> =
self.storage_liveness[block].clone().unwrap().into(); self.storage_liveness[block].clone().unwrap().into();
@ -543,7 +538,7 @@ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockDat
for i in 0..self.always_live_locals.domain_size() { for i in 0..self.always_live_locals.domain_size() {
let l = Local::new(i); let l = Local::new(i);
let needs_storage_dead = storage_liveness.contains(l) let needs_storage_dead = storage_liveness.contains(l)
&& !self.remap.contains_key(&l) && !self.remap.contains(l)
&& !self.always_live_locals.contains(l); && !self.always_live_locals.contains(l);
if needs_storage_dead { if needs_storage_dead {
data.statements data.statements
@ -1037,7 +1032,7 @@ fn compute_layout<'tcx>(
liveness: LivenessInfo, liveness: LivenessInfo,
body: &Body<'tcx>, body: &Body<'tcx>,
) -> ( ) -> (
FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>, IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
CoroutineLayout<'tcx>, CoroutineLayout<'tcx>,
IndexVec<BasicBlock, Option<BitSet<Local>>>, IndexVec<BasicBlock, Option<BitSet<Local>>>,
) { ) {
@ -1098,7 +1093,7 @@ fn compute_layout<'tcx>(
// Create a map from local indices to coroutine struct indices. // Create a map from local indices to coroutine struct indices.
let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> = let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> =
iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect(); iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect();
let mut remap = FxHashMap::default(); let mut remap = IndexVec::from_elem_n(None, saved_locals.domain_size());
for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() { for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() {
let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx); let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx);
let mut fields = IndexVec::new(); let mut fields = IndexVec::new();
@ -1109,7 +1104,7 @@ fn compute_layout<'tcx>(
// around inside coroutines, so it doesn't matter which variant // around inside coroutines, so it doesn't matter which variant
// index we access them by. // index we access them by.
let idx = FieldIdx::from_usize(idx); let idx = FieldIdx::from_usize(idx);
remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); remap[locals[saved_local]] = Some((tys[saved_local].ty, variant_index, idx));
} }
variant_fields.push(fields); variant_fields.push(fields);
variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
@ -1121,7 +1116,9 @@ fn compute_layout<'tcx>(
for var in &body.var_debug_info { for var in &body.var_debug_info {
let VarDebugInfoContents::Place(place) = &var.value else { continue }; let VarDebugInfoContents::Place(place) = &var.value else { continue };
let Some(local) = place.as_local() else { continue }; let Some(local) = place.as_local() else { continue };
let Some(&(_, variant, field)) = remap.get(&local) else { continue }; let Some(&Some((_, variant, field))) = remap.get(local) else {
continue;
};
let saved_local = variant_fields[variant][field]; let saved_local = variant_fields[variant][field];
field_names.get_or_insert_with(saved_local, || var.name); field_names.get_or_insert_with(saved_local, || var.name);
@ -1524,7 +1521,7 @@ fn create_cases<'tcx>(
for i in 0..(body.local_decls.len()) { for i in 0..(body.local_decls.len()) {
let l = Local::new(i); let l = Local::new(i);
let needs_storage_live = point.storage_liveness.contains(l) let needs_storage_live = point.storage_liveness.contains(l)
&& !transform.remap.contains_key(&l) && !transform.remap.contains(l)
&& !transform.always_live_locals.contains(l); && !transform.always_live_locals.contains(l);
if needs_storage_live { if needs_storage_live {
statements statements