From 3764af6119ee32061de6ff25effb2bf72dced487 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 10 Apr 2024 18:42:48 -0400 Subject: [PATCH] Use suggest_impl_trait in return type suggestion --- compiler/rustc_hir_analysis/src/collect.rs | 98 +++++++++++-------- .../src/fn_ctxt/suggestions.rs | 12 +-- .../async-closure-gate.afn.stderr | 4 +- .../async-closure-gate.nofeat.stderr | 4 +- tests/ui/suggestions/return-closures.stderr | 2 +- .../typeck/return_type_containing_closure.rs | 2 +- .../return_type_containing_closure.stderr | 6 +- 7 files changed, 68 insertions(+), 60 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs index a705d3bc107..c1bf65367aa 100644 --- a/compiler/rustc_hir_analysis/src/collect.rs +++ b/compiler/rustc_hir_analysis/src/collect.rs @@ -30,7 +30,7 @@ use rustc_middle::ty::util::{Discr, IntTypeExt}; use rustc_middle::ty::{self, AdtKind, Const, IsSuggestable, ToPredicate, Ty, TyCtxt}; use rustc_span::symbol::{kw, sym, Ident, Symbol}; -use rustc_span::Span; +use rustc_span::{Span, DUMMY_SP}; use rustc_target::abi::FieldIdx; use rustc_target::spec::abi; use rustc_trait_selection::infer::InferCtxtExt; @@ -1383,7 +1383,9 @@ fn infer_return_ty_for_fn_sig<'tcx>( Applicability::MachineApplicable, ); should_recover = true; - } else if let Some(sugg) = suggest_impl_trait(tcx, ret_ty, ty.span, def_id) { + } else if let Some(sugg) = + suggest_impl_trait(&tcx.infer_ctxt().build(), tcx.param_env(def_id), ret_ty) + { diag.span_suggestion( ty.span, "replace with an appropriate return type", @@ -1426,11 +1428,10 @@ fn infer_return_ty_for_fn_sig<'tcx>( } } -fn suggest_impl_trait<'tcx>( - tcx: TyCtxt<'tcx>, +pub fn suggest_impl_trait<'tcx>( + infcx: &InferCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, ret_ty: Ty<'tcx>, - span: Span, - def_id: LocalDefId, ) -> Option { let format_as_assoc: fn(_, _, _, _, _) -> _ = |tcx: TyCtxt<'tcx>, @@ -1464,24 +1465,28 @@ fn suggest_impl_trait<'tcx>( for (trait_def_id, assoc_item_def_id, formatter) in [ ( - tcx.get_diagnostic_item(sym::Iterator), - tcx.get_diagnostic_item(sym::IteratorItem), + infcx.tcx.get_diagnostic_item(sym::Iterator), + infcx.tcx.get_diagnostic_item(sym::IteratorItem), format_as_assoc, ), ( - tcx.lang_items().future_trait(), - tcx.get_diagnostic_item(sym::FutureOutput), + infcx.tcx.lang_items().future_trait(), + infcx.tcx.get_diagnostic_item(sym::FutureOutput), format_as_assoc, ), - (tcx.lang_items().fn_trait(), tcx.lang_items().fn_once_output(), format_as_parenthesized), ( - tcx.lang_items().fn_mut_trait(), - tcx.lang_items().fn_once_output(), + infcx.tcx.lang_items().fn_trait(), + infcx.tcx.lang_items().fn_once_output(), format_as_parenthesized, ), ( - tcx.lang_items().fn_once_trait(), - tcx.lang_items().fn_once_output(), + infcx.tcx.lang_items().fn_mut_trait(), + infcx.tcx.lang_items().fn_once_output(), + format_as_parenthesized, + ), + ( + infcx.tcx.lang_items().fn_once_trait(), + infcx.tcx.lang_items().fn_once_output(), format_as_parenthesized, ), ] { @@ -1491,36 +1496,45 @@ fn suggest_impl_trait<'tcx>( let Some(assoc_item_def_id) = assoc_item_def_id else { continue; }; - if tcx.def_kind(assoc_item_def_id) != DefKind::AssocTy { + if infcx.tcx.def_kind(assoc_item_def_id) != DefKind::AssocTy { continue; } - let param_env = tcx.param_env(def_id); - let infcx = tcx.infer_ctxt().build(); - let args = ty::GenericArgs::for_item(tcx, trait_def_id, |param, _| { - if param.index == 0 { ret_ty.into() } else { infcx.var_for_def(span, param) } + let sugg = infcx.probe(|_| { + let args = ty::GenericArgs::for_item(infcx.tcx, trait_def_id, |param, _| { + if param.index == 0 { ret_ty.into() } else { infcx.var_for_def(DUMMY_SP, param) } + }); + if !infcx + .type_implements_trait(trait_def_id, args, param_env) + .must_apply_modulo_regions() + { + return None; + } + let ocx = ObligationCtxt::new(&infcx); + let item_ty = ocx.normalize( + &ObligationCause::dummy(), + param_env, + Ty::new_projection(infcx.tcx, assoc_item_def_id, args), + ); + // FIXME(compiler-errors): We may benefit from resolving regions here. + if ocx.select_where_possible().is_empty() + && let item_ty = infcx.resolve_vars_if_possible(item_ty) + && let Some(item_ty) = item_ty.make_suggestable(infcx.tcx, false, None) + && let Some(sugg) = formatter( + infcx.tcx, + infcx.resolve_vars_if_possible(args), + trait_def_id, + assoc_item_def_id, + item_ty, + ) + { + return Some(sugg); + } + + None }); - if !infcx.type_implements_trait(trait_def_id, args, param_env).must_apply_modulo_regions() { - continue; - } - let ocx = ObligationCtxt::new(&infcx); - let item_ty = ocx.normalize( - &ObligationCause::misc(span, def_id), - param_env, - Ty::new_projection(tcx, assoc_item_def_id, args), - ); - // FIXME(compiler-errors): We may benefit from resolving regions here. - if ocx.select_where_possible().is_empty() - && let item_ty = infcx.resolve_vars_if_possible(item_ty) - && let Some(item_ty) = item_ty.make_suggestable(tcx, false, None) - && let Some(sugg) = formatter( - tcx, - infcx.resolve_vars_if_possible(args), - trait_def_id, - assoc_item_def_id, - item_ty, - ) - { - return Some(sugg); + + if sugg.is_some() { + return sugg; } } None diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs index 442bfd75746..1e1136ef467 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs @@ -20,6 +20,7 @@ CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node, Path, QPath, Stmt, StmtKind, TyKind, WherePredicate, }; +use rustc_hir_analysis::collect::suggest_impl_trait; use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; use rustc_infer::traits::{self}; use rustc_middle::lint::in_external_macro; @@ -814,17 +815,10 @@ pub(in super::super) fn suggest_missing_return_type( errors::AddReturnTypeSuggestion::Add { span, found: found.to_string() }, ); return true; - } else if let ty::Closure(_, args) = found.kind() - // FIXME(compiler-errors): Get better at printing binders... - && let closure = args.as_closure() - && closure.sig().is_suggestable(self.tcx, false) - { + } else if let Some(sugg) = suggest_impl_trait(self, self.param_env, found) { err.subdiagnostic( self.dcx(), - errors::AddReturnTypeSuggestion::Add { - span, - found: closure.print_as_impl_trait().to_string(), - }, + errors::AddReturnTypeSuggestion::Add { span, found: sugg }, ); return true; } else { diff --git a/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr b/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr index 92f38d5a796..640d946421a 100644 --- a/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr +++ b/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr @@ -62,7 +62,7 @@ error[E0308]: mismatched types --> $DIR/async-closure-gate.rs:27:5 | LL | fn foo3() { - | - help: a return type might be missing here: `-> _` + | - help: try adding a return type: `-> impl Future` LL | / async { LL | | LL | | let _ = #[track_caller] || { @@ -78,7 +78,7 @@ error[E0308]: mismatched types --> $DIR/async-closure-gate.rs:44:5 | LL | fn foo5() { - | - help: a return type might be missing here: `-> _` + | - help: try adding a return type: `-> impl Future` LL | / async { LL | | LL | | let _ = || { diff --git a/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr b/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr index 92f38d5a796..640d946421a 100644 --- a/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr +++ b/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr @@ -62,7 +62,7 @@ error[E0308]: mismatched types --> $DIR/async-closure-gate.rs:27:5 | LL | fn foo3() { - | - help: a return type might be missing here: `-> _` + | - help: try adding a return type: `-> impl Future` LL | / async { LL | | LL | | let _ = #[track_caller] || { @@ -78,7 +78,7 @@ error[E0308]: mismatched types --> $DIR/async-closure-gate.rs:44:5 | LL | fn foo5() { - | - help: a return type might be missing here: `-> _` + | - help: try adding a return type: `-> impl Future` LL | / async { LL | | LL | | let _ = || { diff --git a/tests/ui/suggestions/return-closures.stderr b/tests/ui/suggestions/return-closures.stderr index 97c13200ac3..ef1f50b8a6c 100644 --- a/tests/ui/suggestions/return-closures.stderr +++ b/tests/ui/suggestions/return-closures.stderr @@ -2,7 +2,7 @@ error[E0308]: mismatched types --> $DIR/return-closures.rs:3:5 | LL | fn foo() { - | - help: try adding a return type: `-> impl for<'a> Fn(&'a i32) -> i32` + | - help: try adding a return type: `-> impl FnOnce(&i32) -> i32` LL | LL | |x: &i32| 1i32 | ^^^^^^^^^^^^^^ expected `()`, found closure diff --git a/tests/ui/typeck/return_type_containing_closure.rs b/tests/ui/typeck/return_type_containing_closure.rs index 8b826daeede..b81cac0a58a 100644 --- a/tests/ui/typeck/return_type_containing_closure.rs +++ b/tests/ui/typeck/return_type_containing_closure.rs @@ -1,5 +1,5 @@ #[allow(unused)] -fn foo() { //~ HELP a return type might be missing here +fn foo() { //~ HELP try adding a return type vec!['a'].iter().map(|c| c) //~^ ERROR mismatched types [E0308] //~| NOTE expected `()`, found `Map, ...>` diff --git a/tests/ui/typeck/return_type_containing_closure.stderr b/tests/ui/typeck/return_type_containing_closure.stderr index ea9c74be362..3f14650a82c 100644 --- a/tests/ui/typeck/return_type_containing_closure.stderr +++ b/tests/ui/typeck/return_type_containing_closure.stderr @@ -10,10 +10,10 @@ help: consider using a semicolon here | LL | vec!['a'].iter().map(|c| c); | + -help: a return type might be missing here +help: try adding a return type | -LL | fn foo() -> _ { - | ++++ +LL | fn foo() -> impl Iterator { + | ++++++++++++++++++++++++++++++ error: aborting due to 1 previous error