diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index f116347cc2b..58a27c1f9ef 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -85,6 +85,12 @@ pub fn add_target(&mut self, value: u128, bb: BasicBlock) { self.values.push(value); self.targets.insert(self.targets.len() - 1, bb); } + + /// Returns true if all targets (including the fallback target) are distinct. + #[inline] + pub fn is_distinct(&self) -> bool { + self.targets.iter().collect::>().len() == self.targets.len() + } } pub struct SwitchTargetsIter<'a> { diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index be1158683ac..e766c1ae0f6 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,6 +1,6 @@ use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; +use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; @@ -38,6 +38,11 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { should_cleanup = true; continue; } + if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) + { + should_cleanup = true; + continue; + } } if should_cleanup { @@ -47,8 +52,10 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } trait SimplifyMatch<'tcx> { + /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise. + /// Generic code is written here, and we generally don't need a custom implementation. fn simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, local_decls: &mut IndexVec>, bbs: &mut IndexVec>, @@ -72,9 +79,7 @@ fn simplify( let source_info = bbs[switch_bb_idx].terminator().source_info; let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty); let (_, first) = targets.iter().next().unwrap(); let (from, first) = bbs.pick2_mut(switch_bb_idx, first); from.statements @@ -90,8 +95,11 @@ fn simplify( true } + /// Check that the BBs to be simplified satisfies all distinct and + /// that the terminator are the same. + /// There are also conditions for different ways of simplification. fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -144,7 +152,7 @@ fn new_stmts( /// ``` impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -250,3 +258,211 @@ fn new_stmts( new_stmts.collect() } } + +#[derive(Default)] +struct SimplifyToExp { + transfrom_types: Vec, +} + +#[derive(Clone, Copy)] +enum CompareType<'tcx, 'a> { + Same(&'a StatementKind<'tcx>), + Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), + Discr(&'a Place<'tcx>, Ty<'tcx>), +} + +enum TransfromType { + Same, + Eq, + Discr, +} + +impl From> for TransfromType { + fn from(compare_type: CompareType<'_, '_>) -> Self { + match compare_type { + CompareType::Same(_) => TransfromType::Same, + CompareType::Eq(_, _, _) => TransfromType::Eq, + CompareType::Discr(_, _) => TransfromType::Discr, + } + } +} + +/// If we find that the value of match is the same as the assignment, +/// merge a target block statements into the source block, +/// using cast to transform different integer types. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; +/// } +/// +/// bb1: { +/// unreachable; +/// } +/// +/// bb2: { +/// _0 = const 1_i16; +/// goto -> bb5; +/// } +/// +/// bb3: { +/// _0 = const 2_i16; +/// goto -> bb5; +/// } +/// +/// bb4: { +/// _0 = const 3_i16; +/// goto -> bb5; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _0 = _3 as i16 (IntToInt); +/// goto -> bb5; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool { + if targets.iter().len() < 2 || targets.iter().len() > 64 { + return false; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return false; + } + if !bbs[targets.otherwise()].is_empty_unreachable() { + return false; + } + let mut target_iter = targets.iter(); + let (first_val, first_target) = target_iter.next().unwrap(); + let first_terminator_kind = &bbs[first_target].terminator().kind; + // Check that destinations are identical, and if not, then don't optimize this block + if !targets + .iter() + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) + { + return false; + } + + let first_stmts = &bbs[first_target].statements; + let (second_val, second_target) = target_iter.next().unwrap(); + let second_stmts = &bbs[second_target].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + + let mut compare_types = Vec::new(); + for (f, s) in iter::zip(first_stmts, second_stmts) { + let compare_type = match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => CompareType::Same(f_s), + + // If two statements are assignments with the match values to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty() == s_c.const_.ty() + && f_c.const_.ty().is_integral() => + { + match ( + f_c.const_.try_eval_scalar_int(tcx, param_env), + s_c.const_.try_eval_scalar_int(tcx, param_env), + ) { + (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), + (Some(f), Some(s)) + if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + { + CompareType::Discr(lhs_f, f_c.const_.ty()) + } + _ => return false, + } + } + + // Otherwise we cannot optimize. Try another block. + _ => return false, + }; + compare_types.push(compare_type); + } + + // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. + for (other_val, other_target) in target_iter { + let other_stmts = &bbs[other_target].statements; + if compare_types.len() != other_stmts.len() { + return false; + } + for (f, s) in iter::zip(&compare_types, other_stmts) { + match (*f, &s.kind) { + (CompareType::Same(f_s), s_s) if f_s == s_s => {} + ( + CompareType::Eq(lhs_f, f_ty, val), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} + ( + CompareType::Discr(lhs_f, f_ty), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { + let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { + return false; + }; + if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { + return false; + } + } + _ => return false, + } + } + } + self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect(); + true + } + + fn new_stmts( + &self, + _tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + _param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec> { + let (_, first) = targets.iter().next().unwrap(); + let first = &bbs[first]; + + let new_stmts = + iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) { + (TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(), + ( + TransfromType::Discr, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + ) => { + let operand = Operand::Copy(Place::from(discr_local)); + let r_val = if f_c.const_.ty() == discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) + }; + Statement { + source_info: s.source_info, + kind: StatementKind::Assign(Box::new((*lhs, r_val))), + } + } + _ => unreachable!(), + }); + new_stmts.collect() + } +} diff --git a/tests/codegen/match-optimized.rs b/tests/codegen/match-optimized.rs index 09907edf8f2..5cecafb9f29 100644 --- a/tests/codegen/match-optimized.rs +++ b/tests/codegen/match-optimized.rs @@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 { // CHECK-NEXT: store i8 1, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] // CHECK: [[C]]: -// CHECK-NEXT: store i8 2, ptr %_0, align 1 +// CHECK-NEXT: store i8 3, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] match e { E::A => 0, E::B => 1, - E::C => 2, + E::C => 3, } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff index 1f20349fdec..31ce51dc6de 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -5,37 +5,42 @@ debug i => _1; let mut _0: u128; let mut _2: i128; ++ let mut _3: i128; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const core::num::::MAX; - goto -> bb6; - } - - bb3: { - _0 = const 1_u128; - goto -> bb6; - } - - bb4: { - _0 = const 2_u128; - goto -> bb6; - } - - bb5: { - _0 = const 3_u128; - goto -> bb6; - } - - bb6: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const core::num::::MAX; +- goto -> bb6; +- } +- +- bb3: { +- _0 = const 1_u128; +- goto -> bb6; +- } +- +- bb4: { +- _0 = const 2_u128; +- goto -> bb6; +- } +- +- bb5: { +- _0 = const 3_u128; +- goto -> bb6; +- } +- +- bb6: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u128 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff index 72ad60956ab..9ee01a87a91 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug i => _1; let mut _0: i16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 2_i16; - goto -> bb4; - } - - bb3: { - _0 = const 1_i16; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 2_i16; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 1_i16; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff index 043fdb197a3..aa9fcc60a3e 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: u16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 5_u16; - goto -> bb5; - } - - bb3: { - _0 = const 1_u16; - goto -> bb5; - } - - bb4: { - _0 = const 2_u16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 5_u16; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const 1_u16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_u16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 2e7b7d4e600..d51dd7c5873 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -75,7 +75,9 @@ enum EnumAu8 { // EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff fn match_u8_i16(i: EnumAu8) -> i16 { // CHECK-LABEL: fn match_u8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAu8::A => 1, EnumAu8::B => 2, @@ -144,7 +146,9 @@ enum EnumBu8 { // EMIT_MIR matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff fn match_u8_u16(i: EnumBu8) -> u16 { // CHECK-LABEL: fn match_u8_u16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u16 (IntToInt); + // CHECH: return match i { EnumBu8::A => 1, EnumBu8::B => 2, @@ -248,7 +252,9 @@ enum EnumAi128 { // EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff fn match_i128_u128(i: EnumAi128) -> u128 { // CHECK-LABEL: fn match_i128_u128( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u128 (IntToInt); + // CHECH: return match i { EnumAi128::A => 1, EnumAi128::B => 2, diff --git a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff index 157f9c98353..11a18f58e3a 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: u8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 1_u8; - goto -> bb4; - } - - bb3: { - _0 = const 0_u8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 1_u8; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 0_u8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff index 19083771fd9..809badc41ba 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: i8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 1_i8; - goto -> bb4; - } - - bb3: { - _0 = const 0_i8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 1_i8; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 0_i8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } }