From d25abdc0c52cc08cdd290be325f1be04f3cea548 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 13 Jul 2022 05:39:01 +0000 Subject: [PATCH] Point out custom Fn-family trait impl --- compiler/rustc_middle/src/ty/closure.rs | 8 + .../rustc_typeck/src/check/fn_ctxt/checks.rs | 143 +++++++++++------- .../overloaded-calls-bad.stderr | 16 ++ 3 files changed, 114 insertions(+), 53 deletions(-) diff --git a/compiler/rustc_middle/src/ty/closure.rs b/compiler/rustc_middle/src/ty/closure.rs index f5ce43f3afb..8ead0512274 100644 --- a/compiler/rustc_middle/src/ty/closure.rs +++ b/compiler/rustc_middle/src/ty/closure.rs @@ -128,6 +128,14 @@ impl<'tcx> ClosureKind { None } } + + pub fn to_def_id(&self, tcx: TyCtxt<'_>) -> DefId { + match self { + ClosureKind::Fn => tcx.lang_items().fn_once_trait().unwrap(), + ClosureKind::FnMut => tcx.lang_items().fn_mut_trait().unwrap(), + ClosureKind::FnOnce => tcx.lang_items().fn_trait().unwrap(), + } + } } /// A composite describing a `Place` that is captured by a closure. diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs index 3dd63d74c3f..ec045d3e70c 100644 --- a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs +++ b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs @@ -21,6 +21,7 @@ use rustc_hir::def_id::DefId; use rustc_hir::{ExprKind, Node, QPath}; use rustc_index::vec::IndexVec; use rustc_infer::infer::error_reporting::{FailureCode, ObligationCauseExt}; +use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; use rustc_infer::infer::InferOk; use rustc_infer::infer::TypeTrace; use rustc_middle::ty::adjustment::AllowTwoPhase; @@ -29,7 +30,9 @@ use rustc_middle::ty::{self, DefIdTree, IsSuggestable, Ty}; use rustc_session::Session; use rustc_span::symbol::Ident; use rustc_span::{self, Span}; -use rustc_trait_selection::traits::{self, ObligationCauseCode, StatementAsExpression}; +use rustc_trait_selection::traits::{ + self, ObligationCauseCode, SelectionContext, StatementAsExpression, +}; use std::iter; use std::slice; @@ -393,41 +396,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } if !call_appears_satisfied { - // Next, let's construct the error - let (error_span, full_call_span, ctor_of) = match &call_expr.kind { - hir::ExprKind::Call( - hir::Expr { - span, - kind: - hir::ExprKind::Path(hir::QPath::Resolved( - _, - hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. }, - )), - .. - }, - _, - ) => (call_span, *span, Some(of)), - hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None), - hir::ExprKind::MethodCall(path_segment, _, span) => { - let ident_span = path_segment.ident.span; - let ident_span = if let Some(args) = path_segment.args { - ident_span.with_hi(args.span_ext.hi()) - } else { - ident_span - }; - ( - *span, ident_span, None, // methods are never ctors - ) - } - k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k), - }; - let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span); - let call_name = match ctor_of { - Some(CtorOf::Struct) => "struct", - Some(CtorOf::Variant) => "enum variant", - None => "function", - }; - let compatibility_diagonal = IndexVec::from_raw(compatibility_diagonal); let provided_args = IndexVec::from_iter(provided_args.iter().take(if c_variadic { minimum_input_count @@ -451,13 +419,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { compatibility_diagonal, formal_and_expected_inputs, provided_args, - full_call_span, - error_span, - args_span, - call_name, c_variadic, err_code, fn_def_id, + call_span, call_expr, ); } @@ -468,15 +433,47 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { compatibility_diagonal: IndexVec>, formal_and_expected_inputs: IndexVec, Ty<'tcx>)>, provided_args: IndexVec>, - full_call_span: Span, - error_span: Span, - args_span: Span, - call_name: &str, c_variadic: bool, err_code: &str, fn_def_id: Option, + call_span: Span, call_expr: &hir::Expr<'tcx>, ) { + // Next, let's construct the error + let (error_span, full_call_span, ctor_of) = match &call_expr.kind { + hir::ExprKind::Call( + hir::Expr { + span, + kind: + hir::ExprKind::Path(hir::QPath::Resolved( + _, + hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. }, + )), + .. + }, + _, + ) => (call_span, *span, Some(of)), + hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None), + hir::ExprKind::MethodCall(path_segment, _, span) => { + let ident_span = path_segment.ident.span; + let ident_span = if let Some(args) = path_segment.args { + ident_span.with_hi(args.span_ext.hi()) + } else { + ident_span + }; + ( + *span, ident_span, None, // methods are never ctors + ) + } + k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k), + }; + let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span); + let call_name = match ctor_of { + Some(CtorOf::Struct) => "struct", + Some(CtorOf::Variant) => "enum variant", + None => "function", + }; + // Don't print if it has error types or is just plain `_` fn has_error_or_infer<'tcx>(tys: impl IntoIterator>) -> bool { tys.into_iter().any(|ty| ty.references_error() || ty.is_ty_var()) @@ -1818,17 +1815,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { fn label_fn_like( &self, err: &mut rustc_errors::DiagnosticBuilder<'tcx, rustc_errors::ErrorGuaranteed>, - def_id: Option, + callable_def_id: Option, callee_ty: Option>, ) { - let Some(mut def_id) = def_id else { + let Some(mut def_id) = callable_def_id else { return; }; if let Some(assoc_item) = self.tcx.opt_associated_item(def_id) - && let trait_def_id = assoc_item.trait_item_def_id.unwrap_or_else(|| self.tcx.parent(def_id)) + // Possibly points at either impl or trait item, so try to get it + // to point to trait item, then get the parent. + // This parent might be an impl in the case of an inherent function, + // but the next check will fail. + && let maybe_trait_item_def_id = assoc_item.trait_item_def_id.unwrap_or(def_id) + && let maybe_trait_def_id = self.tcx.parent(maybe_trait_item_def_id) // Just an easy way to check "trait_def_id == Fn/FnMut/FnOnce" - && ty::ClosureKind::from_def_id(self.tcx, trait_def_id).is_some() + && let Some(call_kind) = ty::ClosureKind::from_def_id(self.tcx, maybe_trait_def_id) && let Some(callee_ty) = callee_ty { let callee_ty = callee_ty.peel_refs(); @@ -1853,7 +1855,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { std::iter::zip(instantiated.predicates, instantiated.spans) { if let ty::PredicateKind::Trait(pred) = predicate.kind().skip_binder() - && pred.self_ty() == callee_ty + && pred.self_ty().peel_refs() == callee_ty && ty::ClosureKind::from_def_id(self.tcx, pred.def_id()).is_some() { err.span_note(span, "callable defined here"); @@ -1862,11 +1864,46 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } } } - ty::Opaque(new_def_id, _) | ty::Closure(new_def_id, _) | ty::FnDef(new_def_id, _) => { + ty::Opaque(new_def_id, _) + | ty::Closure(new_def_id, _) + | ty::FnDef(new_def_id, _) => { def_id = new_def_id; } _ => { - return; + // Look for a user-provided impl of a `Fn` trait, and point to it. + let new_def_id = self.probe(|_| { + let trait_ref = ty::TraitRef::new( + call_kind.to_def_id(self.tcx), + self.tcx.mk_substs([ + ty::GenericArg::from(callee_ty), + self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::MiscVariable, + span: rustc_span::DUMMY_SP, + }) + .into(), + ].into_iter()), + ); + let obligation = traits::Obligation::new( + traits::ObligationCause::dummy(), + self.param_env, + ty::Binder::dummy(ty::TraitPredicate { + trait_ref, + constness: ty::BoundConstness::NotConst, + polarity: ty::ImplPolarity::Positive, + }), + ); + match SelectionContext::new(&self).select(&obligation) { + Ok(Some(traits::ImplSource::UserDefined(impl_source))) => { + Some(impl_source.impl_def_id) + } + _ => None + } + }); + if let Some(new_def_id) = new_def_id { + def_id = new_def_id; + } else { + return; + } } } } @@ -1888,8 +1925,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let def_kind = self.tcx.def_kind(def_id); err.span_note(spans, &format!("{} defined here", def_kind.descr(def_id))); - } else if let def_kind @ (DefKind::Closure | DefKind::OpaqueTy) = self.tcx.def_kind(def_id) - { + } else { + let def_kind = self.tcx.def_kind(def_id); err.span_note( self.tcx.def_span(def_id), &format!("{} defined here", def_kind.descr(def_id)), diff --git a/src/test/ui/mismatched_types/overloaded-calls-bad.stderr b/src/test/ui/mismatched_types/overloaded-calls-bad.stderr index 5ed15468fd6..475ea9dfaf1 100644 --- a/src/test/ui/mismatched_types/overloaded-calls-bad.stderr +++ b/src/test/ui/mismatched_types/overloaded-calls-bad.stderr @@ -5,6 +5,12 @@ LL | let ans = s("what"); | - ^^^^^^ expected `isize`, found `&str` | | | arguments to this function are incorrect + | +note: implementation defined here + --> $DIR/overloaded-calls-bad.rs:10:1 + | +LL | impl FnMut<(isize,)> for S { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ error[E0057]: this function takes 1 argument but 0 arguments were supplied --> $DIR/overloaded-calls-bad.rs:29:15 @@ -12,6 +18,11 @@ error[E0057]: this function takes 1 argument but 0 arguments were supplied LL | let ans = s(); | ^-- an argument of type `isize` is missing | +note: implementation defined here + --> $DIR/overloaded-calls-bad.rs:10:1 + | +LL | impl FnMut<(isize,)> for S { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ help: provide the argument | LL | let ans = s(/* isize */); @@ -25,6 +36,11 @@ LL | let ans = s("burma", "shave"); | | | expected `isize`, found `&str` | +note: implementation defined here + --> $DIR/overloaded-calls-bad.rs:10:1 + | +LL | impl FnMut<(isize,)> for S { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ help: remove the extra argument | LL | let ans = s(/* isize */);