diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index 30ff07ee6c3..da6244acb31 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -1234,6 +1234,7 @@ fn report_selection_error( _ => None, }; + let found_node = found_did.and_then(|did| self.tcx.hir().get_if_local(did)); let found_span = found_did.and_then(|did| self.tcx.hir().span_if_local(did)); if self.reported_closure_mismatch.borrow().contains(&(span, found_span)) { @@ -1287,6 +1288,7 @@ fn report_selection_error( found_trait_ref, expected_trait_ref, obligation.cause.code(), + found_node, ) } else { let (closure_span, closure_arg_span, found) = found_did diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index 40c81025471..2923ad352a7 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -258,6 +258,7 @@ fn report_closure_arg_mismatch( found: ty::PolyTraitRef<'tcx>, expected: ty::PolyTraitRef<'tcx>, cause: &ObligationCauseCode<'tcx>, + found_node: Option<Node<'_>>, ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed>; fn note_conflicting_closure_bounds( @@ -1695,6 +1696,7 @@ fn report_closure_arg_mismatch( found: ty::PolyTraitRef<'tcx>, expected: ty::PolyTraitRef<'tcx>, cause: &ObligationCauseCode<'tcx>, + found_node: Option<Node<'_>>, ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> { pub(crate) fn build_fn_sig_ty<'tcx>( infcx: &InferCtxt<'tcx>, @@ -1756,6 +1758,75 @@ pub(crate) fn build_fn_sig_ty<'tcx>( self.note_conflicting_closure_bounds(cause, &mut err); + let found_args = match found.kind() { + ty::FnPtr(f) => f.inputs().skip_binder().iter(), + kind => { + span_bug!(span, "found was converted to a FnPtr above but is now {:?}", kind) + } + }; + let expected_args = match expected.kind() { + ty::FnPtr(f) => f.inputs().skip_binder().iter(), + kind => { + span_bug!(span, "expected was converted to a FnPtr above but is now {:?}", kind) + } + }; + + if let Some(found_node) = found_node { + let fn_decl = match found_node { + Node::Expr(expr) => match &expr.kind { + hir::ExprKind::Closure(hir::Closure { fn_decl, .. }) => fn_decl, + kind => { + span_bug!(found_span, "expression must be a closure but is {:?}", kind) + } + }, + Node::Item(item) => match &item.kind { + hir::ItemKind::Fn(signature, _generics, _body) => signature.decl, + kind => { + span_bug!(found_span, "item must be a function but is {:?}", kind) + } + }, + node => { + span_bug!(found_span, "node must be a expr or item but is {:?}", node) + } + }; + + let arg_spans = fn_decl.inputs.iter().map(|ty| ty.span); + + fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, usize) { + let mut refs = 0; + + while let ty::Ref(_, new_ty, _) = ty.kind() { + ty = *new_ty; + refs += 1; + } + + (ty, refs) + } + + for ((found_arg, expected_arg), arg_span) in + found_args.zip(expected_args).zip(arg_spans) + { + let (found_ty, found_refs) = get_deref_type_and_refs(*found_arg); + let (expected_ty, expected_refs) = get_deref_type_and_refs(*expected_arg); + + if found_ty == expected_ty { + let hint = if found_refs < expected_refs { + "hint: consider borrowing here:" + } else if found_refs == expected_refs { + continue; + } else { + "hint: consider removing the borrow:" + }; + err.span_suggestion_verbose( + arg_span, + hint, + expected_arg.to_string(), + Applicability::MaybeIncorrect, + ); + } + } + } + err }