diff --git a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs index 80474bc154d..1a5ab19e1c2 100644 --- a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs +++ b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs @@ -114,6 +114,66 @@ fn clone(&self ) -> Self { ); } +#[test] +fn test_clone_expand_with_associated_types() { + check( + r#" +//- minicore: derive, clone +trait Trait { + type InWc; + type InFieldQualified; + type InFieldShorthand; + type InGenericArg; +} +trait Marker {} +struct Vec(T); + +#[derive(Clone)] +struct Foo +where + ::InWc: Marker, +{ + qualified: ::InFieldQualified, + shorthand: T::InFieldShorthand, + generic: Vec, +} +"#, + expect![[r#" +trait Trait { + type InWc; + type InFieldQualified; + type InFieldShorthand; + type InGenericArg; +} +trait Marker {} +struct Vec(T); + +#[derive(Clone)] +struct Foo +where + ::InWc: Marker, +{ + qualified: ::InFieldQualified, + shorthand: T::InFieldShorthand, + generic: Vec, +} + +impl core::clone::Clone for Foo where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, { + fn clone(&self ) -> Self { + match self { + Foo { + qualified: qualified, shorthand: shorthand, generic: generic, + } + =>Foo { + qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(), + } + , + } + } +}"#]], + ); +} + #[test] fn test_clone_expand_with_const_generics() { check( diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs index 54706943ac4..da34f50660a 100644 --- a/crates/hir-expand/src/builtin_derive_macro.rs +++ b/crates/hir-expand/src/builtin_derive_macro.rs @@ -4,17 +4,16 @@ use base_db::{CrateOrigin, LangCrateOrigin}; use itertools::izip; use mbe::TokenMap; -use std::collections::HashSet; +use rustc_hash::FxHashSet; use stdx::never; use tracing::debug; -use crate::tt::{self, TokenId}; -use syntax::{ - ast::{ - self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, - HasTypeBounds, PathType, - }, - match_ast, +use crate::{ + name::{AsName, Name}, + tt::{self, TokenId}, +}; +use syntax::ast::{ + self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, }; use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId}; @@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result { debug!("no module item parsed"); ExpandError::Other("no item found".into()) })?; - let node = item.syntax(); - let (name, params, shape) = match_ast! { - match node { - ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)), - ast::Enum(it) => { - let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into()))); - ( - it.name(), - it.generic_param_list(), - AdtShape::Enum { - default_variant, - variants: it.variant_list() - .into_iter() - .flat_map(|x| x.variants()) - .map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::>()? - } - ) - }, - ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union), - _ => { - debug!("unexpected node is {:?}", node); - return Err(ExpandError::Other("expected struct, enum or union".into())) - }, + let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| { + debug!("expected adt, found: {:?}", item); + ExpandError::Other("expected struct, enum or union".into()) + })?; + let (name, generic_param_list, shape) = match &adt { + ast::Adt::Struct(it) => ( + it.name(), + it.generic_param_list(), + AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?), + ), + ast::Adt::Enum(it) => { + let default_variant = it + .variant_list() + .into_iter() + .flat_map(|x| x.variants()) + .position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into()))); + ( + it.name(), + it.generic_param_list(), + AdtShape::Enum { + default_variant, + variants: it + .variant_list() + .into_iter() + .flat_map(|x| x.variants()) + .map(|x| { + Ok(( + name_to_token(&token_map, x.name())?, + VariantShape::from(x.field_list(), &token_map)?, + )) + }) + .collect::>()?, + }, + ) } + ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union), }; - let mut param_type_set: HashSet = HashSet::new(); - let param_types = params + + let mut param_type_set: FxHashSet = FxHashSet::default(); + let param_types = generic_param_list .into_iter() .flat_map(|param_list| param_list.type_or_const_params()) .map(|param| { @@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result { let this = param.name(); match this { Some(x) => { - param_type_set.insert(x.to_string()); + param_type_set.insert(x.as_name()); mbe::syntax_node_to_token_tree(x.syntax()).0 } None => tt::Subtree::empty(), @@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result { (name, ty, bounds) }) .collect(); - let is_associated_type = |p: &PathType| { - if let Some(p) = p.path() { - if let Some(parent) = p.qualifier() { - if let Some(x) = parent.segment() { - if let Some(x) = x.path_type() { - if let Some(x) = x.path() { - if let Some(pname) = x.as_single_name_ref() { - if param_type_set.contains(&pname.to_string()) { - // ::Assoc - return true; - } - } - } - } - } - if let Some(pname) = parent.as_single_name_ref() { - if param_type_set.contains(&pname.to_string()) { - // T::Assoc - return true; - } - } - } - } - false + + // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field + // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to + // also generate trait bound for qualified associated type `::Assoc`, but rustc + // does not do that for some unknown reason. + // + // See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730. + // [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378 + + // It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped + // `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that + // we should not inspect `ast::PathType`s in parameter bounds and where clauses. + let field_list = match adt { + ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()), + ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()), + ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()), }; - let associated_types = node - .descendants() - .filter_map(PathType::cast) - .filter(is_associated_type) + let associated_types = field_list + .into_iter() + .flat_map(|it| it.descendants()) + .filter_map(ast::PathType::cast) + .filter_map(|p| { + let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name(); + param_type_set.contains(&name).then_some(p) + }) .map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0) - .collect::>(); + .collect(); let name_token = name_to_token(&token_map, name)?; Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types }) } @@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option) -> Result tt::Subtree, + make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree, ) -> ExpandResult { let info = match parse_adt(tt) { Ok(info) => info, Err(e) => return ExpandResult::new(tt::Subtree::empty(), e), }; - let trait_body = trait_body(&info); + let trait_body = make_trait_body(&info); let mut where_block = vec![]; let (params, args): (Vec<_>, Vec<_>) = info .param_types diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 829a6ab189e..97ae732a904 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4335,8 +4335,9 @@ impl Tr for Copy { #[derive(Clone)] struct AssocGeneric(T::Assoc); - #[derive(Clone)] - struct AssocGeneric2(::Assoc); + // Currently rustc does not accept this. + // #[derive(Clone)] + // struct AssocGeneric2(::Assoc); #[derive(Clone)] struct AssocGeneric3(Generic); @@ -4361,9 +4362,8 @@ fn f() { let x: &AssocGeneric = &AssocGeneric(NotCopy); let x = x.clone(); //^ &AssocGeneric - let x: &AssocGeneric2 = &AssocGeneric2(NotCopy); - let x = x.clone(); - //^ &AssocGeneric2 + // let x: &AssocGeneric2 = &AssocGeneric2(NotCopy); + // let x = x.clone(); let x: &AssocGeneric3 = &AssocGeneric3(Generic(NotCopy)); let x = x.clone(); //^ &AssocGeneric3