diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index ac7ed3e26f9..53cf121dfeb 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -424,9 +424,10 @@ fn visit_ty(&mut self, t: Ty<'tcx>) -> Self::Result { if let Some(trait_def_id) = trait_def_id { let found_kind = match closure_kind { hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id), - hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => { - self.tcx.async_fn_trait_kind_from_def_id(trait_def_id) - } + hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => self + .tcx + .async_fn_trait_kind_from_def_id(trait_def_id) + .or_else(|| self.tcx.fn_trait_kind_from_def_id(trait_def_id)), _ => None, }; @@ -470,14 +471,37 @@ fn deduce_sig_from_projection( // for closures and async closures, respectively. match closure_kind { hir::ClosureKind::Closure - if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {} + if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => + { + self.extract_sig_from_projection(cause_span, projection) + } hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) - if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {} - _ => return None, + if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => + { + self.extract_sig_from_projection(cause_span, projection) + } + // It's possible we've passed the closure to a (somewhat out-of-fashion) + // `F: FnOnce() -> Fut, Fut: Future` style bound. Let's still + // guide inference here, since it's beneficial for the user. + hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) + if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => + { + self.extract_sig_from_projection_and_future_bound(cause_span, projection) + } + _ => None, } + } + + /// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args + /// and return type to infer a [`ty::PolyFnSig`] for the closure. + fn extract_sig_from_projection( + &self, + cause_span: Option, + projection: ty::PolyProjectionPredicate<'tcx>, + ) -> Option> { + let projection = self.resolve_vars_if_possible(projection); let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); - let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty); debug!(?arg_param_ty); let ty::Tuple(input_tys) = *arg_param_ty.kind() else { @@ -486,7 +510,6 @@ fn deduce_sig_from_projection( // Since this is a return parameter type it is safe to unwrap. let ret_param_ty = projection.skip_binder().term.expect_type(); - let ret_param_ty = self.resolve_vars_if_possible(ret_param_ty); debug!(?ret_param_ty); let sig = projection.rebind(self.tcx.mk_fn_sig( @@ -500,6 +523,69 @@ fn deduce_sig_from_projection( Some(ExpectedSig { cause_span, sig }) } + /// When an async closure is passed to a function that has a "two-part" `Fn` + /// and `Future` trait bound, like: + /// + /// ```rust + /// use std::future::Future; + /// + /// fn not_exactly_an_async_closure(_f: F) + /// where + /// F: FnOnce(String, u32) -> Fut, + /// Fut: Future, + /// {} + /// ``` + /// + /// The we want to be able to extract the signature to guide inference in the async + /// closure. We will have two projection predicates registered in this case. First, + /// we identify the `FnOnce` bound, and if the output type is + /// an inference variable `?Fut`, we check if that is bounded by a `Future` + /// projection. + fn extract_sig_from_projection_and_future_bound( + &self, + cause_span: Option, + projection: ty::PolyProjectionPredicate<'tcx>, + ) -> Option> { + let projection = self.resolve_vars_if_possible(projection); + + let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); + debug!(?arg_param_ty); + + let ty::Tuple(input_tys) = *arg_param_ty.kind() else { + return None; + }; + + // If the return type is a type variable, look for bounds on it. + // We could theoretically support other kinds of return types here, + // but none of them would be useful, since async closures return + // concrete anonymous future types, and their futures are not coerced + // into any other type within the body of the async closure. + let ty::Infer(ty::TyVar(return_vid)) = *projection.skip_binder().term.expect_type().kind() + else { + return None; + }; + + // FIXME: We may want to elaborate here, though I assume this will be exceedingly rare. + for bound in self.obligations_for_self_ty(return_vid) { + if let Some(ret_projection) = bound.predicate.as_projection_clause() + && let Some(ret_projection) = ret_projection.no_bound_vars() + && self.tcx.is_lang_item(ret_projection.def_id(), LangItem::FutureOutput) + { + let sig = projection.rebind(self.tcx.mk_fn_sig( + input_tys, + ret_projection.term.expect_type(), + false, + hir::Safety::Safe, + Abi::Rust, + )); + + return Some(ExpectedSig { cause_span, sig }); + } + } + + None + } + fn sig_of_closure( &self, expr_def_id: LocalDefId, diff --git a/tests/ui/async-await/async-closures/signature-inference-from-two-part-bound.rs b/tests/ui/async-await/async-closures/signature-inference-from-two-part-bound.rs new file mode 100644 index 00000000000..0e2d1ef1208 --- /dev/null +++ b/tests/ui/async-await/async-closures/signature-inference-from-two-part-bound.rs @@ -0,0 +1,27 @@ +//@ edition: 2021 +//@ check-pass +//@ revisions: current next +//@ ignore-compare-mode-next-solver (explicit revisions) +//@[next] compile-flags: -Znext-solver + +#![feature(async_closure)] + +use std::future::Future; +use std::any::Any; + +struct Struct; +impl Struct { + fn method(&self) {} +} + +fn fake_async_closure(_: F) +where + F: Fn(Struct) -> Fut, + Fut: Future, +{} + +fn main() { + fake_async_closure(async |s| { + s.method(); + }) +}