diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index c7e22dfe842..d25c6d471cd 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -3,9 +3,12 @@ //! post-order traversal of the blocks. use crate::MirPass; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_target::abi::Size; pub struct UnreachablePropagation; @@ -20,102 +23,134 @@ impl MirPass<'_> for UnreachablePropagation { } fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut patch = MirPatch::new(body); let mut unreachable_blocks = FxHashSet::default(); - let mut replacements = FxHashMap::default(); for (bb, bb_data) in traversal::postorder(body) { let terminator = bb_data.terminator(); - if terminator.kind == TerminatorKind::Unreachable { - unreachable_blocks.insert(bb); - } else { - let is_unreachable = |succ: BasicBlock| unreachable_blocks.contains(&succ); - let terminator_kind_opt = remove_successors(&terminator.kind, is_unreachable); - - if let Some(terminator_kind) = terminator_kind_opt { - if terminator_kind == TerminatorKind::Unreachable { - unreachable_blocks.insert(bb); - } - replacements.insert(bb, terminator_kind); + let is_unreachable = match &terminator.kind { + TerminatorKind::Unreachable => true, + // This will unconditionally run into an unreachable and is therefore unreachable as well. + TerminatorKind::Goto { target } if unreachable_blocks.contains(target) => { + patch.patch_terminator(bb, TerminatorKind::Unreachable); + true } + // Try to remove unreachable targets from the switch. + TerminatorKind::SwitchInt { .. } => { + remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch) + } + _ => false, + }; + if is_unreachable { + unreachable_blocks.insert(bb); } } + if !tcx + .consider_optimizing(|| format!("UnreachablePropagation {:?} ", body.source.def_id())) + { + return; + } + + patch.apply(body); + // We do want do keep some unreachable blocks, but make them empty. for bb in unreachable_blocks { - if !tcx.consider_optimizing(|| { - format!("UnreachablePropagation {:?} ", body.source.def_id()) - }) { - break; - } - body.basic_blocks_mut()[bb].statements.clear(); } - - for (bb, terminator_kind) in replacements { - if !tcx.consider_optimizing(|| { - format!("UnreachablePropagation {:?} ", body.source.def_id()) - }) { - break; - } - - body.basic_blocks_mut()[bb].terminator_mut().kind = terminator_kind; - } - - // Do not remove dead blocks, let `SimplifyCfg` do it. } } -fn remove_successors<'tcx, F>( - terminator_kind: &TerminatorKind<'tcx>, - is_unreachable: F, -) -> Option> -where - F: Fn(BasicBlock) -> bool, -{ - let terminator = match terminator_kind { - // This will unconditionally run into an unreachable and is therefore unreachable as well. - TerminatorKind::Goto { target } if is_unreachable(*target) => TerminatorKind::Unreachable, - TerminatorKind::SwitchInt { targets, discr } => { - let otherwise = targets.otherwise(); +/// Return whether the current terminator is fully unreachable. +fn remove_successors_from_switch<'tcx>( + tcx: TyCtxt<'tcx>, + bb: BasicBlock, + unreachable_blocks: &FxHashSet, + body: &Body<'tcx>, + patch: &mut MirPatch<'tcx>, +) -> bool { + let terminator = body.basic_blocks[bb].terminator(); + let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() }; + let source_info = terminator.source_info; + let location = body.terminator_loc(bb); - // If all targets are unreachable, we can be unreachable as well. - if targets.all_targets().iter().all(|bb| is_unreachable(*bb)) { - TerminatorKind::Unreachable - } else if is_unreachable(otherwise) { - // If there are multiple targets, don't delete unreachable branches (like an unreachable otherwise) - // unless otherwise is unreachable, in which case deleting a normal branch causes it to be merged with - // the otherwise, keeping its unreachable. - // This looses information about reachability causing worse codegen. - // For example (see tests/codegen/match-optimizes-away.rs) - // - // pub enum Two { A, B } - // pub fn identity(x: Two) -> Two { - // match x { - // Two::A => Two::A, - // Two::B => Two::B, - // } - // } - // - // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to - // turn it into just `x` later. Without the unreachable, such a transformation would be illegal. - // If the otherwise branch is unreachable, we can delete all other unreachable targets, as they will - // still point to the unreachable and therefore not lose reachability information. - let reachable_iter = targets.iter().filter(|(_, bb)| !is_unreachable(*bb)); + let is_unreachable = |bb| unreachable_blocks.contains(&bb); - let new_targets = SwitchTargets::new(reachable_iter, otherwise); + // If there are multiple targets, we want to keep information about reachability for codegen. + // For example (see tests/codegen/match-optimizes-away.rs) + // + // pub enum Two { A, B } + // pub fn identity(x: Two) -> Two { + // match x { + // Two::A => Two::A, + // Two::B => Two::B, + // } + // } + // + // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to + // turn it into just `x` later. Without the unreachable, such a transformation would be illegal. + // + // In order to preserve this information, we record reachable and unreachable targets as + // `Assume` statements in MIR. - // No unreachable branches were removed. - if new_targets.all_targets().len() == targets.all_targets().len() { - return None; - } + let discr_ty = discr.ty(body, tcx); + let discr_size = Size::from_bits(match discr_ty.kind() { + ty::Uint(uint) => uint.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); - TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets } - } else { - // If the otherwise branch is reachable, we don't want to delete any unreachable branches. - return None; - } - } - _ => return None, + let mut add_assumption = |binop, value| { + let local = patch.new_temp(tcx.types.bool, source_info.span); + let value = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::from_scalar(tcx, Scalar::from_uint(value, discr_size), discr_ty), + })); + let cmp = Rvalue::BinaryOp(binop, Box::new((discr.to_copy(), value))); + patch.add_assign(location, local.into(), cmp); + + let assume = NonDivergingIntrinsic::Assume(Operand::Move(local.into())); + patch.add_statement(location, StatementKind::Intrinsic(Box::new(assume))); }; - Some(terminator) + + let reachable_iter = targets.iter().filter(|&(value, bb)| { + let is_unreachable = is_unreachable(bb); + if is_unreachable { + // We remove this target from the switch, so record the inequality using `Assume`. + add_assumption(BinOp::Ne, value); + false + } else { + true + } + }); + + let otherwise = targets.otherwise(); + let new_targets = SwitchTargets::new(reachable_iter, otherwise); + + let num_targets = new_targets.all_targets().len(); + let otherwise_unreachable = is_unreachable(otherwise); + let fully_unreachable = num_targets == 1 && otherwise_unreachable; + + let terminator = match (num_targets, otherwise_unreachable) { + // If all targets are unreachable, we can be unreachable as well. + (1, true) => TerminatorKind::Unreachable, + (1, false) => TerminatorKind::Goto { target: otherwise }, + (2, true) => { + // All targets are unreachable except one. Record the equality, and make it a goto. + let (value, target) = new_targets.iter().next().unwrap(); + add_assumption(BinOp::Eq, value); + TerminatorKind::Goto { target } + } + _ if num_targets == targets.all_targets().len() => { + // Nothing has changed. + return false; + } + _ => TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }, + }; + + patch.patch_terminator(bb, terminator); + fully_unreachable } diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-abort.diff b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-abort.diff index 018b6c1ee95..14a8b22657f 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-abort.diff +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-abort.diff @@ -16,6 +16,7 @@ + scope 5 (inlined unreachable_unchecked) { + scope 6 { + scope 7 (inlined unreachable_unchecked::runtime) { ++ let _5: !; + } + } + } @@ -31,16 +32,23 @@ - _0 = Option::::unwrap_unchecked(move _2) -> [return: bb1, unwind unreachable]; + StorageLive(_3); + StorageLive(_4); ++ StorageLive(_5); + _4 = discriminant(_2); -+ switchInt(move _4) -> [1: bb2, otherwise: bb1]; ++ switchInt(move _4) -> [0: bb1, 1: bb3, otherwise: bb2]; } bb1: { -+ unreachable; ++ assume(const false); ++ _5 = core::panicking::panic_nounwind(const "unsafe precondition(s) violated: hint::unreachable_unchecked must never be reached") -> unwind unreachable; + } + + bb2: { ++ unreachable; ++ } ++ ++ bb3: { + _0 = move ((_2 as Some).0: T); ++ StorageDead(_5); + StorageDead(_4); + StorageDead(_3); StorageDead(_2); diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-unwind.diff b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-unwind.diff index 47845758a3f..a6901ca0892 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-unwind.diff +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.Inline.panic-unwind.diff @@ -16,6 +16,7 @@ + scope 5 (inlined unreachable_unchecked) { + scope 6 { + scope 7 (inlined unreachable_unchecked::runtime) { ++ let _5: !; + } + } + } @@ -31,20 +32,27 @@ - _0 = Option::::unwrap_unchecked(move _2) -> [return: bb1, unwind: bb2]; + StorageLive(_3); + StorageLive(_4); ++ StorageLive(_5); + _4 = discriminant(_2); -+ switchInt(move _4) -> [1: bb2, otherwise: bb1]; ++ switchInt(move _4) -> [0: bb1, 1: bb3, otherwise: bb2]; } bb1: { - StorageDead(_2); - return; -+ unreachable; ++ assume(const false); ++ _5 = core::panicking::panic_nounwind(const "unsafe precondition(s) violated: hint::unreachable_unchecked must never be reached") -> unwind unreachable; } - bb2 (cleanup): { - resume; + bb2: { ++ unreachable; ++ } ++ ++ bb3: { + _0 = move ((_2 as Some).0: T); ++ StorageDead(_5); + StorageDead(_4); + StorageDead(_3); + StorageDead(_2); diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir index 392f085bd4d..37a2d28e0c4 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir @@ -6,7 +6,7 @@ fn unwrap_unchecked(_1: Option) -> T { scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { debug self => _1; let mut _2: isize; - let mut _3: &std::option::Option; + let mut _4: &std::option::Option; scope 2 { debug val => _0; } @@ -14,30 +14,36 @@ fn unwrap_unchecked(_1: Option) -> T { scope 5 (inlined unreachable_unchecked) { scope 6 { scope 7 (inlined unreachable_unchecked::runtime) { + let _3: !; } } } } scope 4 (inlined Option::::is_some) { - debug self => _3; + debug self => _4; } } bb0: { - StorageLive(_3); + StorageLive(_4); StorageLive(_2); _2 = discriminant(_1); - switchInt(move _2) -> [1: bb1, otherwise: bb2]; + switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb3]; } bb1: { - _0 = move ((_1 as Some).0: T); - StorageDead(_2); - StorageDead(_3); - return; + assume(const false); + _3 = core::panicking::panic_nounwind(const "unsafe precondition(s) violated: hint::unreachable_unchecked must never be reached") -> unwind unreachable; } bb2: { + _0 = move ((_1 as Some).0: T); + StorageDead(_2); + StorageDead(_4); + return; + } + + bb3: { unreachable; } } diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir index 392f085bd4d..37a2d28e0c4 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir @@ -6,7 +6,7 @@ fn unwrap_unchecked(_1: Option) -> T { scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { debug self => _1; let mut _2: isize; - let mut _3: &std::option::Option; + let mut _4: &std::option::Option; scope 2 { debug val => _0; } @@ -14,30 +14,36 @@ fn unwrap_unchecked(_1: Option) -> T { scope 5 (inlined unreachable_unchecked) { scope 6 { scope 7 (inlined unreachable_unchecked::runtime) { + let _3: !; } } } } scope 4 (inlined Option::::is_some) { - debug self => _3; + debug self => _4; } } bb0: { - StorageLive(_3); + StorageLive(_4); StorageLive(_2); _2 = discriminant(_1); - switchInt(move _2) -> [1: bb1, otherwise: bb2]; + switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb3]; } bb1: { - _0 = move ((_1 as Some).0: T); - StorageDead(_2); - StorageDead(_3); - return; + assume(const false); + _3 = core::panicking::panic_nounwind(const "unsafe precondition(s) violated: hint::unreachable_unchecked must never be reached") -> unwind unreachable; } bb2: { + _0 = move ((_1 as Some).0: T); + StorageDead(_2); + StorageDead(_4); + return; + } + + bb3: { unreachable; } } diff --git a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir index 65d71199aae..d5d5253a8b7 100644 --- a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir @@ -7,13 +7,14 @@ fn ub_if_b(_1: Thing) -> Thing { scope 1 (inlined unreachable_unchecked) { scope 2 { scope 3 (inlined unreachable_unchecked::runtime) { + let _3: !; } } } bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb1, otherwise: bb2]; + switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb3]; } bb1: { @@ -22,6 +23,11 @@ fn ub_if_b(_1: Thing) -> Thing { } bb2: { + assume(const false); + _3 = core::panicking::panic_nounwind(const "unsafe precondition(s) violated: hint::unreachable_unchecked must never be reached") -> unwind unreachable; + } + + bb3: { unreachable; } } diff --git a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir index 474f43104bb..1ee44e48c8e 100644 --- a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir +++ b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir @@ -11,16 +11,15 @@ fn main() -> () { let mut _7: Test2; let mut _8: isize; let _9: &str; + let mut _10: bool; bb0: { StorageLive(_1); StorageLive(_2); _2 = Test1::C; _3 = discriminant(_2); - switchInt(move _3) -> [2: bb1, otherwise: bb2]; - } - - bb1: { + _10 = Eq(_3, const 2_isize); + assume(move _10); StorageLive(_5); _5 = const "C"; _1 = &(*_5); @@ -31,27 +30,27 @@ fn main() -> () { StorageLive(_7); _7 = Test2::D; _8 = discriminant(_7); - switchInt(move _8) -> [4: bb4, 5: bb3, otherwise: bb2]; + switchInt(move _8) -> [4: bb3, 5: bb2, otherwise: bb1]; } - bb2: { + bb1: { unreachable; } - bb3: { + bb2: { StorageLive(_9); _9 = const "E"; _6 = &(*_9); StorageDead(_9); - goto -> bb5; + goto -> bb4; + } + + bb3: { + _6 = const "D"; + goto -> bb4; } bb4: { - _6 = const "D"; - goto -> bb5; - } - - bb5: { StorageDead(_7); StorageDead(_6); _0 = const (); diff --git a/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-abort.diff b/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-abort.diff index aec22e0328c..6f4faefce14 100644 --- a/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-abort.diff +++ b/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-abort.diff @@ -8,6 +8,8 @@ let _5: (); let mut _6: bool; let mut _7: !; ++ let mut _8: bool; ++ let mut _9: bool; scope 1 { debug _x => _3; let _3: Empty; @@ -24,7 +26,10 @@ bb1: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb2, otherwise: bb6]; +- switchInt(move _2) -> [1: bb2, otherwise: bb6]; ++ _9 = Ne(_2, const 1_isize); ++ assume(move _9); ++ goto -> bb6; } bb2: { diff --git a/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-unwind.diff b/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-unwind.diff index b2e8ab95d94..5bacb42bed3 100644 --- a/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-unwind.diff +++ b/tests/mir-opt/unreachable.main.UnreachablePropagation.panic-unwind.diff @@ -8,6 +8,8 @@ let _5: (); let mut _6: bool; let mut _7: !; ++ let mut _8: bool; ++ let mut _9: bool; scope 1 { debug _x => _3; let _3: Empty; @@ -24,7 +26,10 @@ bb1: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb2, otherwise: bb6]; +- switchInt(move _2) -> [1: bb2, otherwise: bb6]; ++ _9 = Ne(_2, const 1_isize); ++ assume(move _9); ++ goto -> bb6; } bb2: { diff --git a/tests/mir-opt/unreachable.rs b/tests/mir-opt/unreachable.rs index 9e8b0600b19..3d934ace261 100644 --- a/tests/mir-opt/unreachable.rs +++ b/tests/mir-opt/unreachable.rs @@ -13,15 +13,17 @@ fn main() { // CHECK: bb0: { // CHECK: {{_.*}} = empty() // CHECK: bb1: { - // CHECK: switchInt({{.*}}) -> [1: bb2, otherwise: bb6]; + // CHECK: [[ne:_.*]] = Ne({{.*}}, const 1_isize); + // CHECK-NEXT: assume(move [[ne]]); + // CHECK-NEXT: goto -> bb6; // CHECK: bb2: { - // CHECK: unreachable; + // CHECK-NEXT: unreachable; // CHECK: bb3: { - // CHECK: unreachable; + // CHECK-NEXT: unreachable; // CHECK: bb4: { - // CHECK: unreachable; + // CHECK-NEXT: unreachable; // CHECK: bb5: { - // CHECK: unreachable; + // CHECK-NEXT: unreachable; // CHECK: bb6: { // CHECK: return; if let Some(_x) = empty() { diff --git a/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-abort.diff b/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-abort.diff index 713757ce6e0..11d7924e736 100644 --- a/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-abort.diff +++ b/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-abort.diff @@ -9,6 +9,7 @@ let _5: (); let mut _6: bool; let mut _7: !; ++ let mut _8: bool; scope 1 { debug x => _1; scope 2 { @@ -35,7 +36,10 @@ StorageLive(_5); StorageLive(_6); _6 = _1; - switchInt(move _6) -> [0: bb4, otherwise: bb3]; +- switchInt(move _6) -> [0: bb4, otherwise: bb3]; ++ _8 = Ne(_6, const false); ++ assume(move _8); ++ goto -> bb3; } bb3: { diff --git a/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-unwind.diff b/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-unwind.diff index a0479fb9130..df6f5609fbf 100644 --- a/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-unwind.diff +++ b/tests/mir-opt/unreachable_diverging.main.UnreachablePropagation.panic-unwind.diff @@ -9,6 +9,7 @@ let _5: (); let mut _6: bool; let mut _7: !; ++ let mut _8: bool; scope 1 { debug x => _1; scope 2 { @@ -35,7 +36,10 @@ StorageLive(_5); StorageLive(_6); _6 = _1; - switchInt(move _6) -> [0: bb4, otherwise: bb3]; +- switchInt(move _6) -> [0: bb4, otherwise: bb3]; ++ _8 = Ne(_6, const false); ++ assume(move _8); ++ goto -> bb3; } bb3: { diff --git a/tests/mir-opt/unreachable_diverging.rs b/tests/mir-opt/unreachable_diverging.rs index 95ea6cb00f9..b1df6f85262 100644 --- a/tests/mir-opt/unreachable_diverging.rs +++ b/tests/mir-opt/unreachable_diverging.rs @@ -19,7 +19,9 @@ fn main() { // CHECK: bb1: { // CHECK: switchInt({{.*}}) -> [1: bb2, otherwise: bb6]; // CHECK: bb2: { - // CHECK: switchInt({{.*}}) -> [0: bb4, otherwise: bb3]; + // CHECK: [[ne:_.*]] = Ne({{.*}}, const false); + // CHECK: assume(move [[ne]]); + // CHECK: goto -> bb3; // CHECK: bb3: { // CHECK: {{_.*}} = loop_forever() // CHECK: bb4: {