diff --git a/src/librustc_mir/transform/match_branches.rs b/src/librustc_mir/transform/match_branches.rs new file mode 100644 index 00000000000..74da6d5e629 --- /dev/null +++ b/src/librustc_mir/transform/match_branches.rs @@ -0,0 +1,93 @@ +use crate::transform::{MirPass, MirSource}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct MatchBranchSimplification; + +// What's the intent of this pass? +// If one block is found that switches between blocks which both go to the same place +// AND both of these blocks set a similar const in their -> +// condense into 1 block based on discriminant AND goto the destination afterwards + +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) { + let param_env = tcx.param_env(src.def_id()); + let bbs = body.basic_blocks_mut(); + 'outer: for bb_idx in bbs.indices() { + let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: Operand::Move(ref place), + switch_ty, + ref targets, + ref values, + .. + } if targets.len() == 2 && values.len() == 1 => { + (place, values[0], switch_ty, targets[0], targets[1]) + } + // Only optimize switch int statements + _ => continue, + }; + + // Check that destinations are identical, and if not, then don't optimize this block + if &bbs[first].terminator().kind != &bbs[second].terminator().kind { + continue; + } + + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let scnd_stmts = &bbs[second].statements; + if first_stmts.len() != scnd_stmts.len() { + continue; + } + for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same just ignore them. + (f_s, s_s) if f_s == s_s => (), + + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s => { + if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) { + // This should also be a bool because it's writing to the same place + let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap(); + if f_c != s_c { + // have to check this here because f_c & s_c might have + // different spans. + continue; + } + } + continue 'outer; + } + // If there are not exclusively assignments, then ignore this + _ => continue 'outer, + } + } + // Take owenership of items now that we know we can optimize. + let discr = discr.clone(); + let (from, first) = bbs.pick2_mut(bb_idx, first); + + let new_stmts = first.statements.iter().cloned().map(|mut s| { + if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind { + if let Rvalue::Use(Operand::Constant(c)) = rhs { + let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + switch_ty, + crate::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + if let Some(c) = c.literal.try_eval_bool(tcx, param_env) { + let op = if c { BinOp::Eq } else { BinOp::Ne }; + *rhs = Rvalue::BinaryOp(op, Operand::Move(discr), const_cmp); + } + } + } + s + }); + from.statements.extend(new_stmts); + from.terminator_mut().kind = first.terminator().kind.clone(); + } + } +} diff --git a/src/librustc_mir/transform/mod.rs b/src/librustc_mir/transform/mod.rs index 3803ee78fd4..4f26f3bb459 100644 --- a/src/librustc_mir/transform/mod.rs +++ b/src/librustc_mir/transform/mod.rs @@ -29,6 +29,7 @@ pub mod generator; pub mod inline; pub mod instcombine; pub mod instrument_coverage; +pub mod match_branches; pub mod no_landing_pads; pub mod nrvo; pub mod promote_consts; @@ -440,6 +441,7 @@ fn run_optimization_passes<'tcx>( // with async primitives. &generator::StateTransform, &instcombine::InstCombine, + &match_branches::MatchBranchSimplification, &const_prop::ConstProp, &simplify_branches::SimplifyBranches::new("after-const-prop"), &simplify_try::SimplifyArmIdentity, diff --git a/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit new file mode 100644 index 00000000000..df94c897e92 --- /dev/null +++ b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit @@ -0,0 +1,66 @@ +- // MIR for `foo` before MatchBranchSimplification ++ // MIR for `foo` after MatchBranchSimplification + + fn foo(_1: std::option::Option<()>) -> () { + debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11 + let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25 + let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 + + bb0: { + StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + _3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 +- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 ++ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL ++ // ty::Const ++ // + ty: isize ++ // + val: Value(Scalar(0x00000000)) ++ // mir::Constant ++ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1 ++ // + literal: Const { ty: isize, val: Value(Scalar(0x00000000)) } ++ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 + } + + bb1: { + _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + // ty::Const + // + ty: bool + // + val: Value(Scalar(0x00)) + // mir::Constant + // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL + // + literal: Const { ty: bool, val: Value(Scalar(0x00)) } + goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + } + + bb2: { + _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + // ty::Const + // + ty: bool + // + val: Value(Scalar(0x01)) + // mir::Constant + // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL + // + literal: Const { ty: bool, val: Value(Scalar(0x01)) } + goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + } + + bb3: { + switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + } + + bb4: { + _0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + // ty::Const + // + ty: () + // + val: Value(Scalar()) + // mir::Constant + // + span: $DIR/matches_reduce_branches.rs:5:5: 7:6 + // + literal: Const { ty: (), val: Value(Scalar()) } + goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + } + + bb5: { + StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2 + return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2 + } + } + diff --git a/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit new file mode 100644 index 00000000000..06849b4a5d9 --- /dev/null +++ b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit @@ -0,0 +1,66 @@ +- // MIR for `foo` before MatchBranchSimplification ++ // MIR for `foo` after MatchBranchSimplification + + fn foo(_1: std::option::Option<()>) -> () { + debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11 + let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25 + let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 + + bb0: { + StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + _3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 +- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 ++ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL ++ // ty::Const ++ // + ty: isize ++ // + val: Value(Scalar(0x0000000000000000)) ++ // mir::Constant ++ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1 ++ // + literal: Const { ty: isize, val: Value(Scalar(0x0000000000000000)) } ++ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26 + } + + bb1: { + _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + // ty::Const + // + ty: bool + // + val: Value(Scalar(0x00)) + // mir::Constant + // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL + // + literal: Const { ty: bool, val: Value(Scalar(0x00)) } + goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + } + + bb2: { + _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + // ty::Const + // + ty: bool + // + val: Value(Scalar(0x01)) + // mir::Constant + // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL + // + literal: Const { ty: bool, val: Value(Scalar(0x01)) } + goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL + } + + bb3: { + switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + } + + bb4: { + _0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + // ty::Const + // + ty: () + // + val: Value(Scalar()) + // mir::Constant + // + span: $DIR/matches_reduce_branches.rs:5:5: 7:6 + // + literal: Const { ty: (), val: Value(Scalar()) } + goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6 + } + + bb5: { + StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2 + return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2 + } + } + diff --git a/src/test/mir-opt/matches_reduce_branches.rs b/src/test/mir-opt/matches_reduce_branches.rs new file mode 100644 index 00000000000..91b6bfc836b --- /dev/null +++ b/src/test/mir-opt/matches_reduce_branches.rs @@ -0,0 +1,13 @@ +// EMIT_MIR_FOR_EACH_BIT_WIDTH +// EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff + +fn foo(bar: Option<()>) { + if matches!(bar, None) { + () + } +} + +fn main() { + let _ = foo(None); + let _ = foo(Some(())); +}