From 94e1413f60a1725b994f986a607a6f370bb230b4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 6 Oct 2021 17:31:35 +0200 Subject: [PATCH] reset and cleanup --- compiler/rustc_mir_transform/src/lib.rs | 2 + .../src/normalize_array_len.rs | 287 ++++++++++++++++++ src/test/mir-opt/lower_array_len.rs | 50 +++ 3 files changed, 339 insertions(+) create mode 100644 compiler/rustc_mir_transform/src/normalize_array_len.rs create mode 100644 src/test/mir-opt/lower_array_len.rs diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 0ca640cd7b1..9b11c8f0b24 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -58,6 +58,7 @@ mod lower_intrinsics; mod lower_slice_len; mod match_branches; mod multiple_return_terminators; +mod normalize_array_len; mod nrvo; mod remove_noop_landing_pads; mod remove_storage_markers; @@ -488,6 +489,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // machine than on MIR with async primitives. let optimizations_with_generators: &[&dyn MirPass<'tcx>] = &[ &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first + &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering &unreachable_prop::UnreachablePropagation, &uninhabited_enum_branching::UninhabitedEnumBranching, &simplify::SimplifyCfg::new("after-uninhabited-enum-branching"), diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs new file mode 100644 index 00000000000..e6ec7171a47 --- /dev/null +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -0,0 +1,287 @@ +//! This pass eliminates casting of arrays into slices when their length +//! is taken using `.len()` method. Handy to preserve information in MIR for const prop + +use crate::transform::MirPass; +use rustc_data_structures::fx::FxIndexMap; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; + +const MAX_NUM_BLOCKS: usize = 800; +const MAX_NUM_LOCALS: usize = 3000; + +pub struct NormalizeArrayLen; + +impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // if tcx.sess.mir_opt_level() < 3 { + // return; + // } + + // early returns for edge cases of highly unrolled functions + if body.basic_blocks().len() > MAX_NUM_BLOCKS { + return; + } + if body.local_decls().len() > MAX_NUM_LOCALS { + return; + } + normalize_array_len_calls(tcx, body) + } +} + +pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut(); + + // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]` + let mut interesting_locals = BitSet::new_empty(local_decls.len()); + for (local, decl) in local_decls.iter_enumerated() { + match decl.ty.kind() { + ty::Array(..) => { + interesting_locals.insert(local); + } + ty::Ref(.., ty, Mutability::Not) => match ty.kind() { + ty::Array(..) => { + interesting_locals.insert(local); + } + _ => {} + }, + _ => {} + } + } + if interesting_locals.is_empty() { + // we have found nothing to analyze + return; + } + let num_intesting_locals = interesting_locals.count(); + let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); + let mut patches_scratchpad = + FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); + let mut replacements_scratchpad = + FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); + for block in basic_blocks { + // make length calls for arrays [T; N] not to decay into length calls for &[T] + // that forbids constant propagation + normalize_array_len_call( + tcx, + block, + local_decls, + &interesting_locals, + &mut state, + &mut patches_scratchpad, + &mut replacements_scratchpad, + ); + state.clear(); + patches_scratchpad.clear(); + replacements_scratchpad.clear(); + } +} + +struct Patcher<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + patches_scratchpad: &'a FxIndexMap, + replacements_scratchpad: &'a mut FxIndexMap, + local_decls: &'a mut IndexVec>, + statement_idx: usize, +} + +impl<'a, 'tcx> Patcher<'a, 'tcx> { + fn patch_expand_statement( + &mut self, + statement: &mut Statement<'tcx>, + ) -> Option>> { + let idx = self.statement_idx; + if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() { + let mut statements = Vec::with_capacity(2); + + // we are at statement that performs a cast. The only sound way is + // to create another local that performs a similar copy without a cast and then + // use this copy in the Len operation + + match &statement.kind { + StatementKind::Assign(box ( + .., + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + _, + ), + )) => { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + // create new local + let ty = operand.ty(self.local_decls, self.tcx); + let local_decl = + LocalDecl::with_source_info(ty, statement.source_info.clone()); + let local = self.local_decls.push(local_decl); + // make it live + let mut make_live_statement = statement.clone(); + make_live_statement.kind = StatementKind::StorageLive(local); + statements.push(make_live_statement); + // copy into it + + let operand = Operand::Copy(*place); + let mut make_copy_statement = statement.clone(); + let assign_to = Place::from(local); + let rvalue = Rvalue::Use(operand); + make_copy_statement.kind = + StatementKind::Assign(box (assign_to, rvalue)); + statements.push(make_copy_statement); + + // to reorder we have to copy and make NOP + statements.push(statement.clone()); + statement.make_nop(); + + self.replacements_scratchpad.insert(len_statemnt_idx, local); + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + + self.statement_idx += 1; + + Some(statements.into_iter()) + } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() { + let mut statements = Vec::with_capacity(2); + + match &statement.kind { + StatementKind::Assign(box (into, Rvalue::Len(place))) => { + let add_deref = if let Some(..) = place.as_local() { + false + } else if let Some(..) = place.local_or_deref_local() { + true + } else { + unreachable!("it's a bug in the implementation") + }; + // replace len statement + let mut len_statement = statement.clone(); + let mut place = Place::from(local); + if add_deref { + place = self.tcx.mk_place_deref(place); + } + len_statement.kind = StatementKind::Assign(box (*into, Rvalue::Len(place))); + statements.push(len_statement); + + // make temporary dead + let mut make_dead_statement = statement.clone(); + make_dead_statement.kind = StatementKind::StorageDead(local); + statements.push(make_dead_statement); + + // make original statement NOP + statement.make_nop(); + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + + self.statement_idx += 1; + + Some(statements.into_iter()) + } else { + self.statement_idx += 1; + None + } + } +} + +fn normalize_array_len_call<'tcx>( + tcx: TyCtxt<'tcx>, + block: &mut BasicBlockData<'tcx>, + local_decls: &mut IndexVec>, + interesting_locals: &BitSet, + state: &mut FxIndexMap, + patches_scratchpad: &mut FxIndexMap, + replacements_scratchpad: &mut FxIndexMap, +) { + for (statement_idx, statement) in block.statements.iter_mut().enumerate() { + match &mut statement.kind { + StatementKind::Assign(box (place, rvalue)) => { + match rvalue { + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + cast_ty, + ) => { + let local = if let Some(local) = place.as_local() { local } else { return }; + match operand { + Operand::Copy(place) | Operand::Move(place) => { + let operand_local = + if let Some(local) = place.local_or_deref_local() { + local + } else { + return; + }; + if !interesting_locals.contains(operand_local) { + return; + } + let operand_ty = local_decls[operand_local].ty; + match (operand_ty.kind(), cast_ty.kind()) { + (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { + if of_ty_src == of_ty_dst { + // this is a cast from [T; N] into [T], so we are good + state.insert(local, statement_idx); + } + } + // current way of patching doesn't allow to work with `mut` + ( + ty::Ref( + ty::RegionKind::ReErased, + operand_ty, + Mutability::Not, + ), + ty::Ref(ty::RegionKind::ReErased, cast_ty, Mutability::Not), + ) => { + match (operand_ty.kind(), cast_ty.kind()) { + // current way of patching doesn't allow to work with `mut` + (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { + if of_ty_src == of_ty_dst { + // this is a cast from [T; N] into [T], so we are good + state.insert(local, statement_idx); + } + } + _ => {} + } + } + _ => {} + } + } + _ => {} + } + } + Rvalue::Len(place) => { + let local = if let Some(local) = place.local_or_deref_local() { + local + } else { + return; + }; + if let Some(cast_statement_idx) = state.get(&local).copied() { + patches_scratchpad.insert(cast_statement_idx, statement_idx); + } + } + _ => { + // invalidate + state.remove(&place.local); + } + } + } + _ => {} + } + } + + let mut patcher = Patcher { + tcx, + patches_scratchpad: &*patches_scratchpad, + replacements_scratchpad, + local_decls, + statement_idx: 0, + }; + + block.expand_statements(|st| patcher.patch_expand_statement(st)); +} diff --git a/src/test/mir-opt/lower_array_len.rs b/src/test/mir-opt/lower_array_len.rs new file mode 100644 index 00000000000..2b4845bc151 --- /dev/null +++ b/src/test/mir-opt/lower_array_len.rs @@ -0,0 +1,50 @@ + +Viewed +@@ -0,0 +1,47 @@ +// compile-flags: -Z mir-opt-level=3 + +// EMIT_MIR lower_array_len.array_bound.NormalizeArrayLen.diff +// EMIT_MIR lower_array_len.array_bound.SimplifyLocals.diff +// EMIT_MIR lower_array_len.array_bound.InstCombine.diff +pub fn array_bound(index: usize, slice: &[u8; N]) -> u8 { + if index < slice.len() { + slice[index] + } else { + 42 + } +} + +// EMIT_MIR lower_array_len.array_bound_mut.NormalizeArrayLen.diff +// EMIT_MIR lower_array_len.array_bound_mut.SimplifyLocals.diff +// EMIT_MIR lower_array_len.array_bound_mut.InstCombine.diff +pub fn array_bound_mut(index: usize, slice: &mut [u8; N]) -> u8 { + if index < slice.len() { + slice[index] + } else { + slice[0] = 42; + + 42 + } +} + +// EMIT_MIR lower_array_len.array_len.NormalizeArrayLen.diff +// EMIT_MIR lower_array_len.array_len.SimplifyLocals.diff +// EMIT_MIR lower_array_len.array_len.InstCombine.diff +pub fn array_len(arr: &[u8; N]) -> usize { + arr.len() +} + +// EMIT_MIR lower_array_len.array_len_by_value.NormalizeArrayLen.diff +// EMIT_MIR lower_array_len.array_len_by_value.SimplifyLocals.diff +// EMIT_MIR lower_array_len.array_len_by_value.InstCombine.diff +pub fn array_len_by_value(arr: [u8; N]) -> usize { + arr.len() +} + +fn main() { + let _ = array_bound(3, &[0, 1, 2, 3]); + let mut tmp = [0, 1, 2, 3, 4]; + let _ = array_bound_mut(3, &mut [0, 1, 2, 3]); + let _ = array_len(&[0]); + let _ = array_len_by_value([0, 2]); +} \ No newline at end of file