diff --git a/crates/hir_def/src/generics.rs b/crates/hir_def/src/generics.rs index 41134d23b3d..bb8fca009a1 100644 --- a/crates/hir_def/src/generics.rs +++ b/crates/hir_def/src/generics.rs @@ -62,6 +62,7 @@ pub struct GenericParams { pub enum WherePredicate { TypeBound { target: WherePredicateTypeTarget, bound: TypeBound }, Lifetime { target: LifetimeRef, bound: LifetimeRef }, + ForLifetime { lifetimes: Box<[Name]>, target: WherePredicateTypeTarget, bound: TypeBound }, } #[derive(Clone, PartialEq, Eq, Debug)] @@ -69,7 +70,6 @@ pub enum WherePredicateTypeTarget { TypeRef(TypeRef), /// For desugared where predicates that can directly refer to a type param. TypeParam(LocalTypeParamId), - // FIXME: ForLifetime(Vec, TypeRef) } #[derive(Default)] @@ -234,7 +234,7 @@ pub(crate) fn fill_bounds( for bound in node.type_bound_list().iter().flat_map(|type_bound_list| type_bound_list.bounds()) { - self.add_where_predicate_from_bound(lower_ctx, bound, target.clone()); + self.add_where_predicate_from_bound(lower_ctx, bound, None, target.clone()); } } @@ -279,8 +279,25 @@ fn fill_where_predicates(&mut self, lower_ctx: &LowerCtx, where_clause: ast::Whe } else { continue; }; + + let lifetimes: Option> = pred.generic_param_list().map(|param_list| { + // Higher-Ranked Trait Bounds + param_list + .lifetime_params() + .map(|lifetime_param| { + lifetime_param + .lifetime() + .map_or_else(Name::missing, |lt| Name::new_lifetime(<)) + }) + .collect() + }); for bound in pred.type_bound_list().iter().flat_map(|l| l.bounds()) { - self.add_where_predicate_from_bound(lower_ctx, bound, target.clone()); + self.add_where_predicate_from_bound( + lower_ctx, + bound, + lifetimes.as_ref(), + target.clone(), + ); } } } @@ -289,6 +306,7 @@ fn add_where_predicate_from_bound( &mut self, lower_ctx: &LowerCtx, bound: ast::TypeBound, + hrtb_lifetimes: Option<&Box<[Name]>>, target: Either, ) { if bound.question_mark_token().is_some() { @@ -297,9 +315,16 @@ fn add_where_predicate_from_bound( } let bound = TypeBound::from_ast(lower_ctx, bound); let predicate = match (target, bound) { - (Either::Left(type_ref), bound) => WherePredicate::TypeBound { - target: WherePredicateTypeTarget::TypeRef(type_ref), - bound, + (Either::Left(type_ref), bound) => match hrtb_lifetimes { + Some(hrtb_lifetimes) => WherePredicate::ForLifetime { + lifetimes: hrtb_lifetimes.clone(), + target: WherePredicateTypeTarget::TypeRef(type_ref), + bound, + }, + None => WherePredicate::TypeBound { + target: WherePredicateTypeTarget::TypeRef(type_ref), + bound, + }, }, (Either::Right(lifetime), TypeBound::Lifetime(bound)) => { WherePredicate::Lifetime { target: lifetime, bound } diff --git a/crates/hir_ty/src/lower.rs b/crates/hir_ty/src/lower.rs index 8392cb77065..8da56cd11c8 100644 --- a/crates/hir_ty/src/lower.rs +++ b/crates/hir_ty/src/lower.rs @@ -675,7 +675,8 @@ pub(crate) fn from_where_predicate<'a>( where_predicate: &'a WherePredicate, ) -> impl Iterator + 'a { match where_predicate { - WherePredicate::TypeBound { target, bound } => { + WherePredicate::ForLifetime { target, bound, .. } + | WherePredicate::TypeBound { target, bound } => { let self_ty = match target { WherePredicateTypeTarget::TypeRef(type_ref) => Ty::from_hir(ctx, type_ref), WherePredicateTypeTarget::TypeParam(param_id) => { @@ -888,14 +889,13 @@ pub(crate) fn generic_predicates_for_param_query( .where_predicates_in_scope() // we have to filter out all other predicates *first*, before attempting to lower them .filter(|pred| match pred { - WherePredicate::TypeBound { - target: WherePredicateTypeTarget::TypeRef(type_ref), - .. - } => Ty::from_hir_only_param(&ctx, type_ref) == Some(param_id), - WherePredicate::TypeBound { - target: WherePredicateTypeTarget::TypeParam(local_id), - .. - } => *local_id == param_id.local_id, + WherePredicate::ForLifetime { target, .. } + | WherePredicate::TypeBound { target, .. } => match target { + WherePredicateTypeTarget::TypeRef(type_ref) => { + Ty::from_hir_only_param(&ctx, type_ref) == Some(param_id) + } + WherePredicateTypeTarget::TypeParam(local_id) => *local_id == param_id.local_id, + }, WherePredicate::Lifetime { .. } => false, }) .flat_map(|pred| { diff --git a/crates/hir_ty/src/utils.rs b/crates/hir_ty/src/utils.rs index af880c0658c..65b79df0d51 100644 --- a/crates/hir_ty/src/utils.rs +++ b/crates/hir_ty/src/utils.rs @@ -5,7 +5,9 @@ use hir_def::{ adt::VariantData, db::DefDatabase, - generics::{GenericParams, TypeParamData, TypeParamProvenance, WherePredicateTypeTarget}, + generics::{ + GenericParams, TypeParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget, + }, path::Path, resolver::{HasResolver, TypeNs}, type_ref::TypeRef, @@ -27,7 +29,8 @@ fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec { .where_predicates .iter() .filter_map(|pred| match pred { - hir_def::generics::WherePredicate::TypeBound { target, bound } => match target { + WherePredicate::ForLifetime { target, bound, .. } + | WherePredicate::TypeBound { target, bound } => match target { WherePredicateTypeTarget::TypeRef(TypeRef::Path(p)) if p == &Path::from(name![Self]) => { @@ -38,7 +41,7 @@ fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec { } _ => None, }, - hir_def::generics::WherePredicate::Lifetime { .. } => None, + WherePredicate::Lifetime { .. } => None, }) .filter_map(|path| match resolver.resolve_path_in_type_ns_fully(db, path.mod_path()) { Some(TypeNs::TraitId(t)) => Some(t), diff --git a/crates/ide/src/goto_definition.rs b/crates/ide/src/goto_definition.rs index 173509b08bd..d75ae447bb0 100644 --- a/crates/ide/src/goto_definition.rs +++ b/crates/ide/src/goto_definition.rs @@ -1077,4 +1077,32 @@ fn foo<'foobar>(_: &'foobar<|> ()) {} }"#, ) } + + #[test] + #[ignore] // requires the HIR to somehow track these hrtb lifetimes + fn goto_lifetime_hrtb() { + check( + r#"trait Foo {} +fn foo() where for<'a> T: Foo<&'a<|> (u8, u16)>, {} + //^^ +"#, + ); + check( + r#"trait Foo {} +fn foo() where for<'a<|>> T: Foo<&'a (u8, u16)>, {} + //^^ +"#, + ); + } + + #[test] + #[ignore] // requires ForTypes to be implemented + fn goto_lifetime_hrtb_for_type() { + check( + r#"trait Foo {} +fn foo() where T: for<'a> Foo<&'a<|> (u8, u16)>, {} + //^^ +"#, + ); + } }