From 88a86d4ff985265e8b1501dc4ea33e1572670666 Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Tue, 29 Jun 2021 17:35:37 +0200 Subject: [PATCH] Fix deduction of `dyn Fn` closure parameter types --- crates/hir_ty/src/display.rs | 28 ++++----- crates/hir_ty/src/infer.rs | 1 + crates/hir_ty/src/infer/closure.rs | 93 ++++++++++++++++++++++++++++++ crates/hir_ty/src/infer/expr.rs | 6 +- crates/hir_ty/src/tests/traits.rs | 20 +++++++ crates/hir_ty/src/utils.rs | 12 +++- 6 files changed, 139 insertions(+), 21 deletions(-) create mode 100644 crates/hir_ty/src/infer/closure.rs diff --git a/crates/hir_ty/src/display.rs b/crates/hir_ty/src/display.rs index 44f843bf383..60328481950 100644 --- a/crates/hir_ty/src/display.rs +++ b/crates/hir_ty/src/display.rs @@ -2,10 +2,7 @@ //! HIR back into source code, and just displaying them for debugging/testing //! purposes. -use std::{ - array, - fmt::{self, Debug}, -}; +use std::fmt::{self, Debug}; use chalk_ir::BoundVar; use hir_def::{ @@ -23,12 +20,16 @@ use hir_expand::{hygiene::Hygiene, name::Name}; use crate::{ - const_from_placeholder_idx, db::HirDatabase, from_assoc_type_id, from_foreign_def_id, - from_placeholder_idx, lt_from_placeholder_idx, mapping::from_chalk, primitive, subst_prefix, - to_assoc_type_id, utils::generics, AdtId, AliasEq, AliasTy, CallableDefId, CallableSig, Const, - ConstValue, DomainGoal, GenericArg, ImplTraitId, Interner, Lifetime, LifetimeData, - LifetimeOutlives, Mutability, OpaqueTy, ProjectionTy, ProjectionTyExt, QuantifiedWhereClause, - Scalar, TraitRef, TraitRefExt, Ty, TyExt, TyKind, WhereClause, + const_from_placeholder_idx, + db::HirDatabase, + from_assoc_type_id, from_foreign_def_id, from_placeholder_idx, lt_from_placeholder_idx, + mapping::from_chalk, + primitive, subst_prefix, to_assoc_type_id, + utils::{self, generics}, + AdtId, AliasEq, AliasTy, CallableDefId, CallableSig, Const, ConstValue, DomainGoal, GenericArg, + ImplTraitId, Interner, Lifetime, LifetimeData, LifetimeOutlives, Mutability, OpaqueTy, + ProjectionTy, ProjectionTyExt, QuantifiedWhereClause, Scalar, TraitRef, TraitRefExt, Ty, TyExt, + TyKind, WhereClause, }; pub struct HirFormatter<'a> { @@ -706,12 +707,7 @@ fn hir_fmt(&self, f: &mut HirFormatter) -> Result<(), HirDisplayError> { fn fn_traits(db: &dyn DefDatabase, trait_: TraitId) -> impl Iterator { let krate = trait_.lookup(db).container.krate(); - let fn_traits = [ - db.lang_item(krate, "fn".into()), - db.lang_item(krate, "fn_mut".into()), - db.lang_item(krate, "fn_once".into()), - ]; - array::IntoIter::new(fn_traits).into_iter().flatten().flat_map(|it| it.as_trait()) + utils::fn_traits(db, krate) } pub fn write_bounds_like_dyn_trait_with_prefix( diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index 63f37c0ab07..1ca7105f245 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -52,6 +52,7 @@ mod expr; mod pat; mod coerce; +mod closure; /// The entry point of type inference. pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc { diff --git a/crates/hir_ty/src/infer/closure.rs b/crates/hir_ty/src/infer/closure.rs new file mode 100644 index 00000000000..b4ce22b9eb0 --- /dev/null +++ b/crates/hir_ty/src/infer/closure.rs @@ -0,0 +1,93 @@ +//! Inference of closure parameter types based on the closure's expected type. + +use chalk_ir::{fold::Shift, AliasTy, FnSubst, WhereClause}; +use hir_def::HasModule; +use smallvec::SmallVec; + +use crate::{ + to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty, + TyKind, +}; + +use super::{Expectation, InferenceContext}; + +impl InferenceContext<'_> { + pub(super) fn deduce_closure_type_from_expectations( + &mut self, + closure_ty: &Ty, + sig_ty: &Ty, + expectation: &Expectation, + ) { + let expected_ty = match expectation.to_option(&mut self.table) { + Some(ty) => ty, + None => return, + }; + + // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here. + self.coerce(closure_ty, &expected_ty); + + // Deduction based on the expected `dyn Fn` is done separately. + if let TyKind::Dyn(dyn_ty) = expected_ty.kind(&Interner) { + if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) { + let expected_sig_ty = TyKind::Function(sig).intern(&Interner); + + self.unify(sig_ty, &expected_sig_ty); + } + } + } + + fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option { + // Search for predicates like `$self: FnX` and `<$self as FnOnce<...>>::Output == Ret` + + let fn_traits: SmallVec<[ChalkTraitId; 3]> = + utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate()) + .map(|tid| to_chalk_trait_id(tid)) + .collect(); + + for bound in dyn_ty.bounds.map_ref(|b| b.iter(&Interner)) { + let bound = bound.map(|b| b.clone()).fuse_binders(&Interner); + match bound.skip_binders() { + WhereClause::AliasEq(eq) => match &eq.alias { + AliasTy::Projection(projection) => { + let assoc_data = self.db.associated_ty_data(projection.associated_ty_id); + if !fn_traits.contains(&assoc_data.trait_id) { + return None; + } + + // Skip `Self`, get the type argument. + let arg = projection.substitution.as_slice(&Interner).get(1)?; + match arg.ty(&Interner)?.kind(&Interner) { + TyKind::Tuple(_, subst) => { + let generic_args = subst.as_slice(&Interner); + let mut sig_tys = Vec::new(); + for arg in generic_args { + sig_tys.push(arg.ty(&Interner)?.clone()); + } + sig_tys.push(eq.ty.clone()); + + cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature); + return Some(FnPointer { + num_binders: 0, + sig: FnSig { + abi: (), + safety: chalk_ir::Safety::Safe, + variadic: false, + }, + substitution: FnSubst( + Substitution::from_iter(&Interner, sig_tys.clone()) + .shifted_in(&Interner), + ), + }); + } + _ => {} + } + } + AliasTy::Opaque(_) => {} + }, + _ => {} + } + } + + None + } +} diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index c3a5b979f71..2f109297b55 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -278,15 +278,13 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { .intern(&Interner); let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into(); let closure_ty = - TyKind::Closure(closure_id, Substitution::from1(&Interner, sig_ty)) + TyKind::Closure(closure_id, Substitution::from1(&Interner, sig_ty.clone())) .intern(&Interner); // Eagerly try to relate the closure type with the expected // type, otherwise we often won't have enough information to // infer the body. - if let Some(t) = expected.only_has_type(&mut self.table) { - self.coerce(&closure_ty, &t); - } + self.deduce_closure_type_from_expectations(&closure_ty, &sig_ty, expected); // Now go through the argument patterns for (arg_pat, arg_ty) in args.iter().zip(sig_tys) { diff --git a/crates/hir_ty/src/tests/traits.rs b/crates/hir_ty/src/tests/traits.rs index a0ddad570d0..9db30d9f98b 100644 --- a/crates/hir_ty/src/tests/traits.rs +++ b/crates/hir_ty/src/tests/traits.rs @@ -2829,6 +2829,26 @@ fn foo() { ); } +#[test] +fn dyn_fn_param_informs_call_site_closure_signature() { + cov_mark::check!(dyn_fn_param_informs_call_site_closure_signature); + check_types( + r#" +//- minicore: fn, coerce_unsized +struct S; +impl S { + fn inherent(&self) -> u8 { 0 } +} +fn take_dyn_fn(f: &dyn Fn(S)) {} + +fn f() { + take_dyn_fn(&|x| { x.inherent(); }); + //^^^^^^^^^^^^ u8 +} + "#, + ); +} + #[test] fn infer_fn_trait_arg() { check_infer_with_mismatches( diff --git a/crates/hir_ty/src/utils.rs b/crates/hir_ty/src/utils.rs index 2f490fb9202..076b2c8cba3 100644 --- a/crates/hir_ty/src/utils.rs +++ b/crates/hir_ty/src/utils.rs @@ -1,8 +1,9 @@ //! Helper functions for working with def, which don't need to be a separate //! query, but can't be computed directly from `*Data` (ie, which need a `db`). -use std::iter; +use std::{array, iter}; +use base_db::CrateId; use chalk_ir::{fold::Shift, BoundVar, DebruijnIndex}; use hir_def::{ db::DefDatabase, @@ -23,6 +24,15 @@ WhereClause, }; +pub(crate) fn fn_traits(db: &dyn DefDatabase, krate: CrateId) -> impl Iterator { + let fn_traits = [ + db.lang_item(krate, "fn".into()), + db.lang_item(krate, "fn_mut".into()), + db.lang_item(krate, "fn_once".into()), + ]; + array::IntoIter::new(fn_traits).into_iter().flatten().flat_map(|it| it.as_trait()) +} + fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec { let resolver = trait_.resolver(db); // returning the iterator directly doesn't easily work because of