From 3d3b321c60f6ce1ac59edf0706c083aa7fbd1e83 Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Thu, 29 Feb 2024 00:52:03 +0100 Subject: [PATCH] Use an enum instead of manually tracking indices for `target_blocks` --- .../rustc_mir_build/src/build/matches/mod.rs | 34 +++-- .../rustc_mir_build/src/build/matches/test.rs | 117 ++++++++++-------- .../building/issue_49232.main.built.after.mir | 8 +- ...fg-initial.after-ElaborateDrops.after.diff | 15 +-- ...fg-initial.after-ElaborateDrops.after.diff | 15 +-- 5 files changed, 112 insertions(+), 77 deletions(-) diff --git a/compiler/rustc_mir_build/src/build/matches/mod.rs b/compiler/rustc_mir_build/src/build/matches/mod.rs index daa0349789e..aea52fc497f 100644 --- a/compiler/rustc_mir_build/src/build/matches/mod.rs +++ b/compiler/rustc_mir_build/src/build/matches/mod.rs @@ -1160,6 +1160,19 @@ pub(crate) struct Test<'tcx> { kind: TestKind<'tcx>, } +/// The branch to be taken after a test. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum TestBranch<'tcx> { + /// Success branch, used for tests with two possible outcomes. + Success, + /// Branch corresponding to this constant. + Constant(Const<'tcx>, u128), + /// Branch corresponding to this variant. + Variant(VariantIdx), + /// Failure branch for tests with two possible outcomes, and "otherwise" branch for other tests. + Failure, +} + /// `ArmHasGuard` is a wrapper around a boolean flag. It indicates whether /// a match arm has a guard expression attached to it. #[derive(Copy, Clone, Debug)] @@ -1636,11 +1649,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { match_place: &PlaceBuilder<'tcx>, test: &Test<'tcx>, mut candidates: &'b mut [&'c mut Candidate<'pat, 'tcx>], - ) -> (&'b mut [&'c mut Candidate<'pat, 'tcx>], Vec>>) { + ) -> ( + &'b mut [&'c mut Candidate<'pat, 'tcx>], + FxIndexMap, Vec<&'b mut Candidate<'pat, 'tcx>>>, + ) { // For each of the N possible outcomes, create a (initially empty) vector of candidates. // Those are the candidates that apply if the test has that particular outcome. - let mut target_candidates: Vec>> = vec![]; - target_candidates.resize_with(test.targets(), Default::default); + let mut target_candidates: FxIndexMap<_, Vec<&mut Candidate<'pat, 'tcx>>> = + test.targets().into_iter().map(|branch| (branch, Vec::new())).collect(); let total_candidate_count = candidates.len(); @@ -1648,11 +1664,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // point we may encounter a candidate where the test is not relevant; at that point, we stop // sorting. while let Some(candidate) = candidates.first_mut() { - let Some(idx) = self.sort_candidate(&match_place, &test, candidate) else { + let Some(branch) = self.sort_candidate(&match_place, &test, candidate) else { break; }; let (candidate, rest) = candidates.split_first_mut().unwrap(); - target_candidates[idx].push(candidate); + target_candidates[&branch].push(candidate); candidates = rest; } @@ -1797,9 +1813,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // apply. Collect a list of blocks where control flow will // branch if one of the `target_candidate` sets is not // exhaustive. - let target_blocks: Vec<_> = target_candidates + let target_blocks: FxIndexMap<_, _> = target_candidates .into_iter() - .map(|mut candidates| { + .map(|(branch, mut candidates)| { if !candidates.is_empty() { let candidate_start = self.cfg.start_new_block(); self.match_candidates( @@ -1809,9 +1825,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { remainder_start, &mut *candidates, ); - candidate_start + (branch, candidate_start) } else { - remainder_start + (branch, remainder_start) } }) .collect(); diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs index d811141f50f..d003ae8d803 100644 --- a/compiler/rustc_mir_build/src/build/matches/test.rs +++ b/compiler/rustc_mir_build/src/build/matches/test.rs @@ -6,7 +6,7 @@ // the candidates based on the result. use crate::build::expr::as_place::PlaceBuilder; -use crate::build::matches::{Candidate, MatchPair, Test, TestCase, TestKind}; +use crate::build::matches::{Candidate, MatchPair, Test, TestBranch, TestCase, TestKind}; use crate::build::Builder; use rustc_data_structures::fx::FxIndexMap; use rustc_hir::{LangItem, RangeEnd}; @@ -129,11 +129,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block: BasicBlock, place_builder: &PlaceBuilder<'tcx>, test: &Test<'tcx>, - target_blocks: Vec, + target_blocks: FxIndexMap, BasicBlock>, ) { let place = place_builder.to_place(self); let place_ty = place.ty(&self.local_decls, self.tcx); - debug!(?place, ?place_ty,); + debug!(?place, ?place_ty); + let target_block = |branch| target_blocks[&branch]; let source_info = self.source_info(test.span); match test.kind { @@ -141,20 +142,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // Variants is a BitVec of indexes into adt_def.variants. let num_enum_variants = adt_def.variants().len(); debug_assert_eq!(target_blocks.len(), num_enum_variants + 1); - let otherwise_block = *target_blocks.last().unwrap(); + let otherwise_block = target_block(TestBranch::Failure); let tcx = self.tcx; let switch_targets = SwitchTargets::new( adt_def.discriminants(tcx).filter_map(|(idx, discr)| { if variants.contains(idx) { debug_assert_ne!( - target_blocks[idx.index()], + target_block(TestBranch::Variant(idx)), otherwise_block, "no candidates for tested discriminant: {discr:?}", ); - Some((discr.val, target_blocks[idx.index()])) + Some((discr.val, target_block(TestBranch::Variant(idx)))) } else { debug_assert_eq!( - target_blocks[idx.index()], + target_block(TestBranch::Variant(idx)), otherwise_block, "found candidates for untested discriminant: {discr:?}", ); @@ -185,9 +186,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestKind::SwitchInt { ref options } => { // The switch may be inexhaustive so we have a catch-all block debug_assert_eq!(options.len() + 1, target_blocks.len()); - let otherwise_block = *target_blocks.last().unwrap(); + let otherwise_block = target_block(TestBranch::Failure); let switch_targets = SwitchTargets::new( - options.values().copied().zip(target_blocks), + options + .iter() + .map(|(&val, &bits)| (bits, target_block(TestBranch::Constant(val, bits)))), otherwise_block, ); let terminator = TerminatorKind::SwitchInt { @@ -198,18 +201,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } TestKind::If => { - let [false_bb, true_bb] = *target_blocks else { - bug!("`TestKind::If` should have two targets") - }; - let terminator = TerminatorKind::if_(Operand::Copy(place), true_bb, false_bb); + debug_assert_eq!(target_blocks.len(), 2); + let success_block = target_block(TestBranch::Success); + let fail_block = target_block(TestBranch::Failure); + let terminator = + TerminatorKind::if_(Operand::Copy(place), success_block, fail_block); self.cfg.terminate(block, self.source_info(match_start_span), terminator); } TestKind::Eq { value, ty } => { let tcx = self.tcx; - let [success_block, fail_block] = *target_blocks else { - bug!("`TestKind::Eq` should have two target blocks") - }; + debug_assert_eq!(target_blocks.len(), 2); + let success_block = target_block(TestBranch::Success); + let fail_block = target_block(TestBranch::Failure); if let ty::Adt(def, _) = ty.kind() && Some(def.did()) == tcx.lang_items().string() { @@ -286,9 +290,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } TestKind::Range(ref range) => { - let [success, fail] = *target_blocks else { - bug!("`TestKind::Range` should have two target blocks"); - }; + debug_assert_eq!(target_blocks.len(), 2); + let success = target_block(TestBranch::Success); + let fail = target_block(TestBranch::Failure); // Test `val` by computing `lo <= val && val <= hi`, using primitive comparisons. let val = Operand::Copy(place); @@ -333,15 +337,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // expected = let expected = self.push_usize(block, source_info, len); - let [true_bb, false_bb] = *target_blocks else { - bug!("`TestKind::Len` should have two target blocks"); - }; + debug_assert_eq!(target_blocks.len(), 2); + let success_block = target_block(TestBranch::Success); + let fail_block = target_block(TestBranch::Failure); // result = actual == expected OR result = actual < expected // branch based on result self.compare( block, - true_bb, - false_bb, + success_block, + fail_block, source_info, op, Operand::Move(actual), @@ -526,10 +530,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// Given that we are performing `test` against `test_place`, this job /// sorts out what the status of `candidate` will be after the test. See - /// `test_candidates` for the usage of this function. The returned index is - /// the index that this candidate should be placed in the - /// `target_candidates` vec. The candidate may be modified to update its - /// `match_pairs`. + /// `test_candidates` for the usage of this function. The candidate may + /// be modified to update its `match_pairs`. /// /// So, for example, if this candidate is `x @ Some(P0)` and the `Test` is /// a variant test, then we would modify the candidate to be `(x as @@ -556,7 +558,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { test_place: &PlaceBuilder<'tcx>, test: &Test<'tcx>, candidate: &mut Candidate<'pat, 'tcx>, - ) -> Option { + ) -> Option> { // Find the match_pair for this place (if any). At present, // afaik, there can be at most one. (In the future, if we // adopted a more general `@` operator, there might be more @@ -576,7 +578,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) => { assert_eq!(adt_def, tested_adt_def); fully_matched = true; - Some(variant_index.as_usize()) + Some(TestBranch::Variant(variant_index)) } // If we are performing a switch over integers, then this informs integer @@ -584,12 +586,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // // FIXME(#29623) we could use PatKind::Range to rule // things out here, in some cases. - (TestKind::SwitchInt { options }, TestCase::Constant { value }) + (TestKind::SwitchInt { options }, &TestCase::Constant { value }) if is_switch_ty(match_pair.pattern.ty) => { fully_matched = true; - let index = options.get_index_of(value).unwrap(); - Some(index) + let bits = options.get(&value).unwrap(); + Some(TestBranch::Constant(value, *bits)) } (TestKind::SwitchInt { options }, TestCase::Range(range)) => { fully_matched = false; @@ -599,7 +601,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { not_contained.then(|| { // No switch values are contained in the pattern range, // so the pattern can be matched only if this test fails. - options.len() + TestBranch::Failure }) } @@ -608,7 +610,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let value = value.try_eval_bool(self.tcx, self.param_env).unwrap_or_else(|| { span_bug!(test.span, "expected boolean value but got {value:?}") }); - Some(value as usize) + Some(if value { TestBranch::Success } else { TestBranch::Failure }) } ( @@ -620,14 +622,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // on true, min_len = len = $actual_length, // on false, len != $actual_length fully_matched = true; - Some(0) + Some(TestBranch::Success) } (Ordering::Less, _) => { // test_len < pat_len. If $actual_len = test_len, // then $actual_len < pat_len and we don't have // enough elements. fully_matched = false; - Some(1) + Some(TestBranch::Failure) } (Ordering::Equal | Ordering::Greater, true) => { // This can match both if $actual_len = test_len >= pat_len, @@ -639,7 +641,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // test_len != pat_len, so if $actual_len = test_len, then // $actual_len != pat_len. fully_matched = false; - Some(1) + Some(TestBranch::Failure) } } } @@ -653,20 +655,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // $actual_len >= test_len = pat_len, // so we can match. fully_matched = true; - Some(0) + Some(TestBranch::Success) } (Ordering::Less, _) | (Ordering::Equal, false) => { // test_len <= pat_len. If $actual_len < test_len, // then it is also < pat_len, so the test passing is // necessary (but insufficient). fully_matched = false; - Some(0) + Some(TestBranch::Success) } (Ordering::Greater, false) => { // test_len > pat_len. If $actual_len >= test_len > pat_len, // then we know we won't have a match. fully_matched = false; - Some(1) + Some(TestBranch::Failure) } (Ordering::Greater, true) => { // test_len < pat_len, and is therefore less @@ -680,12 +682,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { (TestKind::Range(test), &TestCase::Range(pat)) => { if test.as_ref() == pat { fully_matched = true; - Some(0) + Some(TestBranch::Success) } else { fully_matched = false; // If the testing range does not overlap with pattern range, // the pattern can be matched only if this test fails. - if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None } + if !test.overlaps(pat, self.tcx, self.param_env)? { + Some(TestBranch::Failure) + } else { + None + } } } (TestKind::Range(range), &TestCase::Constant { value }) => { @@ -693,7 +699,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if !range.contains(value, self.tcx, self.param_env)? { // `value` is not contained in the testing range, // so `value` can be matched only if this test fails. - Some(1) + Some(TestBranch::Failure) } else { None } @@ -704,7 +710,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if test_val == case_val => { fully_matched = true; - Some(0) + Some(TestBranch::Success) } ( @@ -747,18 +753,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } -impl Test<'_> { - pub(super) fn targets(&self) -> usize { +impl<'tcx> Test<'tcx> { + pub(super) fn targets(&self) -> Vec> { match self.kind { - TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => 2, + TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => { + vec![TestBranch::Success, TestBranch::Failure] + } TestKind::Switch { adt_def, .. } => { // While the switch that we generate doesn't test for all // variants, we have a target for each variant and the // otherwise case, and we make sure that all of the cases not // specified have the same block. - adt_def.variants().len() + 1 + adt_def + .variants() + .indices() + .map(|idx| TestBranch::Variant(idx)) + .chain([TestBranch::Failure]) + .collect() } - TestKind::SwitchInt { ref options } => options.len() + 1, + TestKind::SwitchInt { ref options } => options + .iter() + .map(|(val, bits)| TestBranch::Constant(*val, *bits)) + .chain([TestBranch::Failure]) + .collect(), } } } diff --git a/tests/mir-opt/building/issue_49232.main.built.after.mir b/tests/mir-opt/building/issue_49232.main.built.after.mir index d09a1748a8b..166e28ce51d 100644 --- a/tests/mir-opt/building/issue_49232.main.built.after.mir +++ b/tests/mir-opt/building/issue_49232.main.built.after.mir @@ -25,7 +25,7 @@ fn main() -> () { StorageLive(_3); _3 = const true; PlaceMention(_3); - switchInt(_3) -> [0: bb4, otherwise: bb6]; + switchInt(_3) -> [0: bb6, otherwise: bb4]; } bb3: { @@ -34,7 +34,8 @@ fn main() -> () { } bb4: { - falseEdge -> [real: bb8, imaginary: bb6]; + _0 = const (); + goto -> bb13; } bb5: { @@ -42,8 +43,7 @@ fn main() -> () { } bb6: { - _0 = const (); - goto -> bb13; + falseEdge -> [real: bb8, imaginary: bb4]; } bb7: { diff --git a/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff b/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff index 619fda339a6..307f7105dd2 100644 --- a/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff +++ b/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff @@ -42,11 +42,15 @@ } bb2: { -- switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb4]; +- switchInt((_2.0: bool)) -> [0: bb4, otherwise: bb3]; + switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb17]; } bb3: { +- falseEdge -> [real: bb20, imaginary: bb4]; +- } +- +- bb4: { StorageLive(_15); _15 = (_2.1: bool); StorageLive(_16); @@ -55,12 +59,8 @@ + goto -> bb16; } - bb4: { -- falseEdge -> [real: bb20, imaginary: bb3]; -- } -- - bb5: { -- falseEdge -> [real: bb13, imaginary: bb4]; +- falseEdge -> [real: bb13, imaginary: bb3]; - } - - bb6: { @@ -68,6 +68,7 @@ - } - - bb7: { ++ bb4: { _0 = const 1_i32; - drop(_7) -> [return: bb18, unwind: bb25]; + drop(_7) -> [return: bb15, unwind: bb22]; @@ -183,7 +184,7 @@ StorageDead(_12); StorageDead(_8); StorageDead(_6); -- falseEdge -> [real: bb2, imaginary: bb4]; +- falseEdge -> [real: bb2, imaginary: bb3]; + goto -> bb2; } diff --git a/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff b/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff index 619fda339a6..307f7105dd2 100644 --- a/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff +++ b/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff @@ -42,11 +42,15 @@ } bb2: { -- switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb4]; +- switchInt((_2.0: bool)) -> [0: bb4, otherwise: bb3]; + switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb17]; } bb3: { +- falseEdge -> [real: bb20, imaginary: bb4]; +- } +- +- bb4: { StorageLive(_15); _15 = (_2.1: bool); StorageLive(_16); @@ -55,12 +59,8 @@ + goto -> bb16; } - bb4: { -- falseEdge -> [real: bb20, imaginary: bb3]; -- } -- - bb5: { -- falseEdge -> [real: bb13, imaginary: bb4]; +- falseEdge -> [real: bb13, imaginary: bb3]; - } - - bb6: { @@ -68,6 +68,7 @@ - } - - bb7: { ++ bb4: { _0 = const 1_i32; - drop(_7) -> [return: bb18, unwind: bb25]; + drop(_7) -> [return: bb15, unwind: bb22]; @@ -183,7 +184,7 @@ StorageDead(_12); StorageDead(_8); StorageDead(_6); -- falseEdge -> [real: bb2, imaginary: bb4]; +- falseEdge -> [real: bb2, imaginary: bb3]; + goto -> bb2; }