From 3760d919d86b1f4ad464ee1370ba6389e161a7a2 Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Sun, 29 Oct 2023 05:37:59 +0100 Subject: [PATCH] Cleanup check_match code paths --- .../src/thir/pattern/check_match.rs | 214 +++++++++--------- 1 file changed, 104 insertions(+), 110 deletions(-) diff --git a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs index 694e4c07bd5..eadd7a310b7 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs @@ -42,7 +42,7 @@ pub(crate) fn check_match(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Result<(), Err for param in thir.params.iter() { if let Some(box ref pattern) = param.pat { - visitor.check_irrefutable(pattern, "function argument", None); + visitor.check_binding_is_irrefutable(pattern, "function argument", None); } } visitor.error @@ -66,16 +66,17 @@ use RefutableFlag::*; #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum LetSource { None, + PlainLet, IfLet, IfLetGuard, LetElse, WhileLet, } -struct MatchVisitor<'a, 'p, 'tcx> { +struct MatchVisitor<'thir, 'p, 'tcx> { tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, - thir: &'a Thir<'tcx>, + thir: &'thir Thir<'tcx>, lint_level: HirId, let_source: LetSource, pattern_arena: &'p TypedArena>, @@ -85,8 +86,10 @@ struct MatchVisitor<'a, 'p, 'tcx> { error: Result<(), ErrorGuaranteed>, } -impl<'a, 'tcx> Visitor<'a, 'tcx> for MatchVisitor<'a, '_, 'tcx> { - fn thir(&self) -> &'a Thir<'tcx> { +// Visitor for a thir body. This calls `check_match`, `check_let` and `check_let_chain` as +// appropriate. +impl<'thir, 'tcx> Visitor<'thir, 'tcx> for MatchVisitor<'thir, '_, 'tcx> { + fn thir(&self) -> &'thir Thir<'tcx> { self.thir } @@ -101,7 +104,7 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for MatchVisitor<'a, '_, 'tcx> { } Some(Guard::IfLet(ref pat, expr)) => { this.with_let_source(LetSource::IfLetGuard, |this| { - this.check_let(pat, expr, LetSource::IfLetGuard, pat.span); + this.check_let(pat, Some(expr), pat.span); this.visit_pat(pat); this.visit_expr(&this.thir[expr]); }); @@ -148,45 +151,43 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for MatchVisitor<'a, '_, 'tcx> { }; self.check_match(scrutinee, arms, source, ex.span); } - ExprKind::Let { box ref pat, expr } => { - self.check_let(pat, expr, self.let_source, ex.span); + ExprKind::Let { box ref pat, expr } if !matches!(self.let_source, LetSource::None) => { + self.check_let(pat, Some(expr), ex.span); } - ExprKind::LogicalOp { op: LogicalOp::And, .. } => { - self.check_let_chain(self.let_source, ex); + ExprKind::LogicalOp { op: LogicalOp::And, .. } + if !matches!(self.let_source, LetSource::None) => + { + self.check_let_chain(ex); } _ => {} }; + // If we got e.g. `let pat1 = x1 && let pat2 = x2` above, we will now traverse the two + // `let`s. In order not to check them twice we set `LetSource::None`. self.with_let_source(LetSource::None, |this| visit::walk_expr(this, ex)); } fn visit_stmt(&mut self, stmt: &Stmt<'tcx>) { - let old_lint_level = self.lint_level; match stmt.kind { StmtKind::Let { box ref pattern, initializer, else_block, lint_level, span, .. } => { - if let LintLevel::Explicit(lint_level) = lint_level { - self.lint_level = lint_level; - } - - if let Some(initializer) = initializer - && else_block.is_some() - { - self.check_let(pattern, initializer, LetSource::LetElse, span); - } - - if else_block.is_none() { - self.check_irrefutable(pattern, "local binding", Some(span)); - } + self.with_lint_level(lint_level, |this| { + let let_source = + if else_block.is_some() { LetSource::LetElse } else { LetSource::PlainLet }; + this.with_let_source(let_source, |this| { + this.check_let(pattern, initializer, span) + }); + visit::walk_stmt(this, stmt); + }); + } + StmtKind::Expr { .. } => { + visit::walk_stmt(self, stmt); } - _ => {} } - visit::walk_stmt(self, stmt); - self.lint_level = old_lint_level; } } -impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { +impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { #[instrument(level = "trace", skip(self, f))] fn with_let_source(&mut self, let_source: LetSource, f: impl FnOnce(&mut Self)) { let old_let_source = self.let_source; @@ -228,24 +229,37 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { } } - fn new_cx(&self, hir_id: HirId, refutable: bool) -> MatchCheckCtxt<'p, 'tcx> { + fn new_cx(&self, refutability: RefutableFlag) -> MatchCheckCtxt<'p, 'tcx> { + let refutable = match refutability { + Irrefutable => false, + Refutable => true, + }; MatchCheckCtxt { tcx: self.tcx, param_env: self.param_env, - module: self.tcx.parent_module(hir_id).to_def_id(), + module: self.tcx.parent_module(self.lint_level).to_def_id(), pattern_arena: &self.pattern_arena, refutable, } } #[instrument(level = "trace", skip(self))] - fn check_let(&mut self, pat: &Pat<'tcx>, scrutinee: ExprId, source: LetSource, span: Span) { - if let LetSource::None = source { - return; + fn check_let(&mut self, pat: &Pat<'tcx>, scrutinee: Option, span: Span) { + assert!(self.let_source != LetSource::None); + if let LetSource::PlainLet = self.let_source { + self.check_binding_is_irrefutable(pat, "local binding", Some(span)) + } else { + let Ok(irrefutable) = self.is_let_irrefutable(pat) else { return }; + if irrefutable { + report_irrefutable_let_patterns( + self.tcx, + self.lint_level, + self.let_source, + 1, + span, + ); + } } - let mut cx = self.new_cx(self.lint_level, true); - let Ok(tpat) = self.lower_pattern(&cx, pat) else { return }; - self.check_let_reachability(&mut cx, self.lint_level, source, tpat, span); } fn check_match( @@ -255,15 +269,13 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { source: hir::MatchSource, expr_span: Span, ) { - let mut cx = self.new_cx(self.lint_level, true); + let cx = self.new_cx(Refutable); let mut tarms = Vec::with_capacity(arms.len()); for &arm in arms { let arm = &self.thir.arms[arm]; let got_error = self.with_lint_level(arm.lint_level, |this| { - let Ok(pat) = this.lower_pattern(&mut cx, &arm.pattern) else { - return true; - }; + let Ok(pat) = this.lower_pattern(&cx, &arm.pattern) else { return true }; let arm = MatchArm { pat, hir_id: this.lint_level, has_guard: arm.guard.is_some() }; tarms.push(arm); false @@ -299,34 +311,18 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { debug_assert_eq!(pat.span.desugaring_kind(), Some(DesugaringKind::ForLoop)); let PatKind::Variant { ref subpatterns, .. } = pat.kind else { bug!() }; let [pat_field] = &subpatterns[..] else { bug!() }; - self.check_irrefutable(&pat_field.pattern, "`for` loop binding", None); + self.check_binding_is_irrefutable(&pat_field.pattern, "`for` loop binding", None); } else { - self.error = Err(non_exhaustive_match( + self.error = Err(report_non_exhaustive_match( &cx, self.thir, scrut_ty, scrut.span, witnesses, arms, expr_span, )); } } } - fn check_let_reachability( - &mut self, - cx: &mut MatchCheckCtxt<'p, 'tcx>, - pat_id: HirId, - source: LetSource, - pat: &'p DeconstructedPat<'p, 'tcx>, - span: Span, - ) { - if is_let_irrefutable(cx, pat_id, pat) { - irrefutable_let_patterns(cx.tcx, pat_id, source, 1, span); - } - } - #[instrument(level = "trace", skip(self))] - fn check_let_chain(&mut self, let_source: LetSource, expr: &Expr<'tcx>) { - if let LetSource::None = let_source { - return; - } - + fn check_let_chain(&mut self, expr: &Expr<'tcx>) { + assert!(self.let_source != LetSource::None); let top_expr_span = expr.span; // Lint level enclosing `next_expr`. @@ -336,7 +332,7 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { // and record chain members that aren't let exprs. let mut chain_refutabilities = Vec::new(); - let mut got_lowering_error = false; + let mut got_error = false; let mut next_expr = Some(expr); while let Some(mut expr) = next_expr { while let ExprKind::Scope { value, lint_level, .. } = expr.kind { @@ -364,14 +360,15 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { } let value = match expr.kind { ExprKind::Let { box ref pat, expr: _ } => { - let mut ncx = self.new_cx(expr_lint_level, true); - if let Ok(tpat) = self.lower_pattern(&mut ncx, pat) { - let refutable = !is_let_irrefutable(&mut ncx, expr_lint_level, tpat); - Some((expr.span, refutable)) - } else { - got_lowering_error = true; - None - } + self.with_lint_level(LintLevel::Explicit(expr_lint_level), |this| { + match this.is_let_irrefutable(pat) { + Ok(irrefutable) => Some((expr.span, !irrefutable)), + Err(_) => { + got_error = true; + None + } + } + }) } _ => None, }; @@ -380,17 +377,17 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { debug!(?chain_refutabilities); chain_refutabilities.reverse(); - if got_lowering_error { + if got_error { return; } - // Third, emit the actual warnings. + // Emit the actual warnings. if chain_refutabilities.iter().all(|r| matches!(*r, Some((_, false)))) { // The entire chain is made up of irrefutable `let` statements - irrefutable_let_patterns( + report_irrefutable_let_patterns( self.tcx, self.lint_level, - let_source, + self.let_source, chain_refutabilities.len(), top_expr_span, ); @@ -409,7 +406,7 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { // so can't always be moved out. // FIXME: Add checking whether the bindings are actually used in the prefix, // and lint if they are not. - if !matches!(let_source, LetSource::WhileLet | LetSource::IfLetGuard) { + if !matches!(self.let_source, LetSource::WhileLet | LetSource::IfLetGuard) { // Emit the lint let prefix = &chain_refutabilities[..until]; let span_start = prefix[0].unwrap().0; @@ -444,18 +441,33 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { } } + fn analyze_binding( + &mut self, + pat: &Pat<'tcx>, + refutability: RefutableFlag, + ) -> Result<(MatchCheckCtxt<'p, 'tcx>, UsefulnessReport<'p, 'tcx>), ErrorGuaranteed> { + let cx = self.new_cx(refutability); + let pat = self.lower_pattern(&cx, pat)?; + let arms = [MatchArm { pat, hir_id: self.lint_level, has_guard: false }]; + let report = compute_match_usefulness(&cx, &arms, self.lint_level, pat.ty(), pat.span()); + Ok((cx, report)) + } + + fn is_let_irrefutable(&mut self, pat: &Pat<'tcx>) -> Result { + let (cx, report) = self.analyze_binding(pat, Refutable)?; + // Report if the pattern is unreachable, which can only occur when the type is + // uninhabited. This also reports unreachable sub-patterns. + report_arm_reachability(&cx, &report); + // If the list of witnesses is empty, the match is exhaustive, + // i.e. the `if let` pattern is irrefutable. + Ok(report.non_exhaustiveness_witnesses.is_empty()) + } + #[instrument(level = "trace", skip(self))] - fn check_irrefutable(&mut self, pat: &Pat<'tcx>, origin: &str, sp: Option) { - let mut cx = self.new_cx(self.lint_level, false); + fn check_binding_is_irrefutable(&mut self, pat: &Pat<'tcx>, origin: &str, sp: Option) { + let pattern_ty = pat.ty; - let Ok(pattern) = self.lower_pattern(&mut cx, pat) else { return }; - let pattern_ty = pattern.ty(); - let arm = MatchArm { pat: pattern, hir_id: self.lint_level, has_guard: false }; - let report = - compute_match_usefulness(&cx, &[arm], self.lint_level, pattern_ty, pattern.span()); - - // Note: we ignore whether the pattern is unreachable (i.e. whether the type is empty). We - // only care about exhaustiveness here. + let Ok((cx, report)) = self.analyze_binding(pat, Irrefutable) else { return }; let witnesses = report.non_exhaustiveness_witnesses; if witnesses.is_empty() { // The pattern is irrefutable. @@ -509,16 +521,16 @@ impl<'p, 'tcx> MatchVisitor<'_, 'p, 'tcx> { // Emit an extra note if the first uncovered witness would be uninhabited // if we disregard visibility. - let witness_1_is_privately_uninhabited = if cx.tcx.features().exhaustive_patterns + let witness_1_is_privately_uninhabited = if self.tcx.features().exhaustive_patterns && let Some(witness_1) = witnesses.get(0) && let ty::Adt(adt, args) = witness_1.ty().kind() && adt.is_enum() && let Constructor::Variant(variant_index) = witness_1.ctor() { let variant = adt.variant(*variant_index); - let inhabited = variant.inhabited_predicate(cx.tcx, *adt).instantiate(cx.tcx, args); - assert!(inhabited.apply(cx.tcx, cx.param_env, cx.module)); - !inhabited.apply_ignore_module(cx.tcx, cx.param_env) + let inhabited = variant.inhabited_predicate(self.tcx, *adt).instantiate(self.tcx, args); + assert!(inhabited.apply(self.tcx, cx.param_env, cx.module)); + !inhabited.apply_ignore_module(self.tcx, cx.param_env) } else { false }; @@ -673,7 +685,7 @@ fn check_for_bindings_named_same_as_variants( BindingsWithVariantName { // If this is an irrefutable pattern, and there's > 1 variant, // then we can't actually match on this. Applying the below - // suggestion would produce code that breaks on `check_irrefutable`. + // suggestion would produce code that breaks on `check_binding_is_irrefutable`. suggestion: if rf == Refutable || variant_count == 1 { Some(pat.span) } else { @@ -686,7 +698,7 @@ fn check_for_bindings_named_same_as_variants( } } -fn irrefutable_let_patterns( +fn report_irrefutable_let_patterns( tcx: TyCtxt<'_>, id: HirId, source: LetSource, @@ -700,7 +712,7 @@ fn irrefutable_let_patterns( } match source { - LetSource::None => bug!(), + LetSource::None | LetSource::PlainLet => bug!(), LetSource::IfLet => emit_diag!(IrrefutableLetPatternsIfLet), LetSource::IfLetGuard => emit_diag!(IrrefutableLetPatternsIfLetGuard), LetSource::LetElse => emit_diag!(IrrefutableLetPatternsLetElse), @@ -708,24 +720,6 @@ fn irrefutable_let_patterns( } } -fn is_let_irrefutable<'p, 'tcx>( - cx: &mut MatchCheckCtxt<'p, 'tcx>, - pat_id: HirId, - pat: &'p DeconstructedPat<'p, 'tcx>, -) -> bool { - let arms = [MatchArm { pat, hir_id: pat_id, has_guard: false }]; - let report = compute_match_usefulness(&cx, &arms, pat_id, pat.ty(), pat.span()); - - // Report if the pattern is unreachable, which can only occur when the type is uninhabited. - // This also reports unreachable sub-patterns though, so we can't just replace it with an - // `is_uninhabited` check. - report_arm_reachability(&cx, &report); - - // If the list of witnesses is empty, the match is exhaustive, - // i.e. the `if let` pattern is irrefutable. - report.non_exhaustiveness_witnesses.is_empty() -} - /// Report unreachable arms, if any. fn report_arm_reachability<'p, 'tcx>( cx: &MatchCheckCtxt<'p, 'tcx>, @@ -776,7 +770,7 @@ fn pat_is_catchall(pat: &DeconstructedPat<'_, '_>) -> bool { } /// Report that a match is not exhaustive. -fn non_exhaustive_match<'p, 'tcx>( +fn report_non_exhaustive_match<'p, 'tcx>( cx: &MatchCheckCtxt<'p, 'tcx>, thir: &Thir<'tcx>, scrut_ty: Ty<'tcx>,