Simplify some mir passes by using let chains

This commit is contained in:
DaniPopes 2023-10-09 05:22:31 +02:00
parent bf9a1c8a19
commit 47ebffabb8
No known key found for this signature in database
GPG Key ID: 0F09640DDB7AC692
3 changed files with 43 additions and 86 deletions

View File

@ -2,9 +2,8 @@
use crate::MirPass; use crate::MirPass;
use rustc_middle::mir::*; use rustc_middle::mir::*;
use rustc_middle::ty::GenericArgsRef; use rustc_middle::ty::{self, TyCtxt};
use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::symbol::sym;
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{FieldIdx, VariantIdx}; use rustc_target::abi::{FieldIdx, VariantIdx};
pub struct LowerIntrinsics; pub struct LowerIntrinsics;
@ -16,12 +15,10 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
let terminator = block.terminator.as_mut().unwrap(); let terminator = block.terminator.as_mut().unwrap();
if let TerminatorKind::Call { func, args, destination, target, .. } = if let TerminatorKind::Call { func, args, destination, target, .. } =
&mut terminator.kind &mut terminator.kind
&& let ty::FnDef(def_id, generic_args) = *func.ty(local_decls, tcx).kind()
&& tcx.is_intrinsic(def_id)
{ {
let func_ty = func.ty(local_decls, tcx); let intrinsic_name = tcx.item_name(def_id);
let Some((intrinsic_name, generic_args)) = resolve_rust_intrinsic(tcx, func_ty)
else {
continue;
};
match intrinsic_name { match intrinsic_name {
sym::unreachable => { sym::unreachable => {
terminator.kind = TerminatorKind::Unreachable; terminator.kind = TerminatorKind::Unreachable;
@ -309,15 +306,3 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
} }
} }
} }
fn resolve_rust_intrinsic<'tcx>(
tcx: TyCtxt<'tcx>,
func_ty: Ty<'tcx>,
) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
if let ty::FnDef(def_id, args) = *func_ty.kind() {
if tcx.is_intrinsic(def_id) {
return Some((tcx.item_name(def_id), args));
}
}
None
}

View File

@ -34,67 +34,44 @@ pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
} }
} }
struct SliceLenPatchInformation<'tcx> {
add_statement: Statement<'tcx>,
new_terminator_kind: TerminatorKind<'tcx>,
}
fn lower_slice_len_call<'tcx>( fn lower_slice_len_call<'tcx>(
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
block: &mut BasicBlockData<'tcx>, block: &mut BasicBlockData<'tcx>,
local_decls: &IndexSlice<Local, LocalDecl<'tcx>>, local_decls: &IndexSlice<Local, LocalDecl<'tcx>>,
slice_len_fn_item_def_id: DefId, slice_len_fn_item_def_id: DefId,
) { ) {
let mut patch_found: Option<SliceLenPatchInformation<'_>> = None;
let terminator = block.terminator(); let terminator = block.terminator();
match &terminator.kind { if let TerminatorKind::Call {
TerminatorKind::Call { func,
func, args,
args, destination,
destination, target: Some(bb),
target: Some(bb), call_source: CallSource::Normal,
call_source: CallSource::Normal, ..
.. } = &terminator.kind
} => { // some heuristics for fast rejection
// some heuristics for fast rejection && let [arg] = &args[..]
if args.len() != 1 { && let Some(arg) = arg.place()
return; && let ty::FnDef(fn_def_id, _) = func.ty(local_decls, tcx).kind()
} && *fn_def_id == slice_len_fn_item_def_id
let Some(arg) = args[0].place() else { return }; {
let func_ty = func.ty(local_decls, tcx); // perform modifications from something like:
match func_ty.kind() { // _5 = core::slice::<impl [u8]>::len(move _6) -> bb1
ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => { // into:
// perform modifications // _5 = Len(*_6)
// from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1` // goto bb1
// into:
// ```
// _5 = Len(*_6)
// goto bb1
// ```
// make new RValue for Len // make new RValue for Len
let deref_arg = tcx.mk_place_deref(arg); let deref_arg = tcx.mk_place_deref(arg);
let r_value = Rvalue::Len(deref_arg); let r_value = Rvalue::Len(deref_arg);
let len_statement_kind = let len_statement_kind =
StatementKind::Assign(Box::new((*destination, r_value))); StatementKind::Assign(Box::new((*destination, r_value)));
let add_statement = let add_statement =
Statement { kind: len_statement_kind, source_info: terminator.source_info }; Statement { kind: len_statement_kind, source_info: terminator.source_info };
// modify terminator into simple Goto // modify terminator into simple Goto
let new_terminator_kind = TerminatorKind::Goto { target: *bb }; let new_terminator_kind = TerminatorKind::Goto { target: *bb };
let patch = SliceLenPatchInformation { add_statement, new_terminator_kind };
patch_found = Some(patch);
}
_ => {}
}
}
_ => {}
}
if let Some(SliceLenPatchInformation { add_statement, new_terminator_kind }) = patch_found {
block.statements.push(add_statement); block.statements.push(add_statement);
block.terminator_mut().kind = new_terminator_kind; block.terminator_mut().kind = new_terminator_kind;
} }

View File

@ -30,22 +30,17 @@ fn get_switched_on_type<'tcx>(
let terminator = block_data.terminator(); let terminator = block_data.terminator();
// Only bother checking blocks which terminate by switching on a local. // Only bother checking blocks which terminate by switching on a local.
if let Some(local) = get_discriminant_local(&terminator.kind) { if let Some(local) = get_discriminant_local(&terminator.kind)
let stmt_before_term = (!block_data.statements.is_empty()) && let [.., stmt_before_term] = &block_data.statements[..]
.then(|| &block_data.statements[block_data.statements.len() - 1].kind); && let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
&& l.as_local() == Some(local)
if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term && let ty = place.ty(body, tcx).ty
{ && ty.is_enum()
if l.as_local() == Some(local) { {
let ty = place.ty(body, tcx).ty; Some(ty)
if ty.is_enum() { } else {
return Some(ty); None
}
}
}
} }
None
} }
fn variant_discriminants<'tcx>( fn variant_discriminants<'tcx>(