Auto merge of #15000 - lowr:fix/builtin-derive-bound-for-assoc, r=HKalbasi
fix: only generate trait bound for associated types in field types Given the following definitions: ```rust trait Trait { type A; type B; type C; } #[derive(Clone)] struct S<T: Trait> where T::A: Send, { qualified: <T as Trait>::B, shorthand: T::C, } ``` we currently expand the derive macro to: ```rust impl<T> Clone for S<T> where T: Trait + Clone, T::A: Clone, T::B: Clone, T::C: Clone, { /* ... */ } ``` This does not match how rustc expands it. Specifically, `Clone` bounds for `T::A` and `T::B` should not be generated. The criteria for associated types to get bound seem to be 1) the associated type appears as part of field types AND 2) it's written in the shorthand form. I have no idea why rustc doesn't consider qualified associated types (there's even a comment that suggests they should be considered; see rust-lang/rust#50730), but it's important to follow rustc.
This commit is contained in:
commit
085a3112ae
@ -114,6 +114,66 @@ fn clone(&self ) -> Self {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone_expand_with_associated_types() {
|
||||
check(
|
||||
r#"
|
||||
//- minicore: derive, clone
|
||||
trait Trait {
|
||||
type InWc;
|
||||
type InFieldQualified;
|
||||
type InFieldShorthand;
|
||||
type InGenericArg;
|
||||
}
|
||||
trait Marker {}
|
||||
struct Vec<T>(T);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Foo<T: Trait>
|
||||
where
|
||||
<T as Trait>::InWc: Marker,
|
||||
{
|
||||
qualified: <T as Trait>::InFieldQualified,
|
||||
shorthand: T::InFieldShorthand,
|
||||
generic: Vec<T::InGenericArg>,
|
||||
}
|
||||
"#,
|
||||
expect![[r#"
|
||||
trait Trait {
|
||||
type InWc;
|
||||
type InFieldQualified;
|
||||
type InFieldShorthand;
|
||||
type InGenericArg;
|
||||
}
|
||||
trait Marker {}
|
||||
struct Vec<T>(T);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Foo<T: Trait>
|
||||
where
|
||||
<T as Trait>::InWc: Marker,
|
||||
{
|
||||
qualified: <T as Trait>::InFieldQualified,
|
||||
shorthand: T::InFieldShorthand,
|
||||
generic: Vec<T::InGenericArg>,
|
||||
}
|
||||
|
||||
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
|
||||
fn clone(&self ) -> Self {
|
||||
match self {
|
||||
Foo {
|
||||
qualified: qualified, shorthand: shorthand, generic: generic,
|
||||
}
|
||||
=>Foo {
|
||||
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
|
||||
}
|
||||
,
|
||||
}
|
||||
}
|
||||
}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone_expand_with_const_generics() {
|
||||
check(
|
||||
|
@ -4,17 +4,16 @@
|
||||
use base_db::{CrateOrigin, LangCrateOrigin};
|
||||
use itertools::izip;
|
||||
use mbe::TokenMap;
|
||||
use std::collections::HashSet;
|
||||
use rustc_hash::FxHashSet;
|
||||
use stdx::never;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::tt::{self, TokenId};
|
||||
use syntax::{
|
||||
ast::{
|
||||
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
|
||||
HasTypeBounds, PathType,
|
||||
},
|
||||
match_ast,
|
||||
use crate::{
|
||||
name::{AsName, Name},
|
||||
tt::{self, TokenId},
|
||||
};
|
||||
use syntax::ast::{
|
||||
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
|
||||
};
|
||||
|
||||
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
|
||||
@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
debug!("no module item parsed");
|
||||
ExpandError::Other("no item found".into())
|
||||
})?;
|
||||
let node = item.syntax();
|
||||
let (name, params, shape) = match_ast! {
|
||||
match node {
|
||||
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
|
||||
ast::Enum(it) => {
|
||||
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
|
||||
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
|
||||
debug!("expected adt, found: {:?}", item);
|
||||
ExpandError::Other("expected struct, enum or union".into())
|
||||
})?;
|
||||
let (name, generic_param_list, shape) = match &adt {
|
||||
ast::Adt::Struct(it) => (
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
|
||||
),
|
||||
ast::Adt::Enum(it) => {
|
||||
let default_variant = it
|
||||
.variant_list()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.variants())
|
||||
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
|
||||
(
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
AdtShape::Enum {
|
||||
default_variant,
|
||||
variants: it.variant_list()
|
||||
variants: it
|
||||
.variant_list()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.variants())
|
||||
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
|
||||
}
|
||||
.map(|x| {
|
||||
Ok((
|
||||
name_to_token(&token_map, x.name())?,
|
||||
VariantShape::from(x.field_list(), &token_map)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<_, ExpandError>>()?,
|
||||
},
|
||||
)
|
||||
},
|
||||
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
|
||||
_ => {
|
||||
debug!("unexpected node is {:?}", node);
|
||||
return Err(ExpandError::Other("expected struct, enum or union".into()))
|
||||
},
|
||||
}
|
||||
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
|
||||
};
|
||||
let mut param_type_set: HashSet<String> = HashSet::new();
|
||||
let param_types = params
|
||||
|
||||
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
|
||||
let param_types = generic_param_list
|
||||
.into_iter()
|
||||
.flat_map(|param_list| param_list.type_or_const_params())
|
||||
.map(|param| {
|
||||
@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
let this = param.name();
|
||||
match this {
|
||||
Some(x) => {
|
||||
param_type_set.insert(x.to_string());
|
||||
param_type_set.insert(x.as_name());
|
||||
mbe::syntax_node_to_token_tree(x.syntax()).0
|
||||
}
|
||||
None => tt::Subtree::empty(),
|
||||
@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
(name, ty, bounds)
|
||||
})
|
||||
.collect();
|
||||
let is_associated_type = |p: &PathType| {
|
||||
if let Some(p) = p.path() {
|
||||
if let Some(parent) = p.qualifier() {
|
||||
if let Some(x) = parent.segment() {
|
||||
if let Some(x) = x.path_type() {
|
||||
if let Some(x) = x.path() {
|
||||
if let Some(pname) = x.as_single_name_ref() {
|
||||
if param_type_set.contains(&pname.to_string()) {
|
||||
// <T as Trait>::Assoc
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(pname) = parent.as_single_name_ref() {
|
||||
if param_type_set.contains(&pname.to_string()) {
|
||||
// T::Assoc
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
|
||||
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
|
||||
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
|
||||
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
|
||||
// does not do that for some unknown reason.
|
||||
//
|
||||
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
|
||||
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
|
||||
|
||||
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
|
||||
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
|
||||
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
|
||||
let field_list = match adt {
|
||||
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
|
||||
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
|
||||
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
|
||||
};
|
||||
let associated_types = node
|
||||
.descendants()
|
||||
.filter_map(PathType::cast)
|
||||
.filter(is_associated_type)
|
||||
let associated_types = field_list
|
||||
.into_iter()
|
||||
.flat_map(|it| it.descendants())
|
||||
.filter_map(ast::PathType::cast)
|
||||
.filter_map(|p| {
|
||||
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
|
||||
param_type_set.contains(&name).then_some(p)
|
||||
})
|
||||
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
|
||||
.collect::<Vec<_>>();
|
||||
.collect();
|
||||
let name_token = name_to_token(&token_map, name)?;
|
||||
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
|
||||
}
|
||||
@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
|
||||
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
|
||||
/// therefore does not get bound by the derived trait.
|
||||
fn expand_simple_derive(
|
||||
tt: &tt::Subtree,
|
||||
trait_path: tt::Subtree,
|
||||
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
|
||||
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
|
||||
) -> ExpandResult<tt::Subtree> {
|
||||
let info = match parse_adt(tt) {
|
||||
Ok(info) => info,
|
||||
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
|
||||
};
|
||||
let trait_body = trait_body(&info);
|
||||
let trait_body = make_trait_body(&info);
|
||||
let mut where_block = vec![];
|
||||
let (params, args): (Vec<_>, Vec<_>) = info
|
||||
.param_types
|
||||
|
@ -4335,8 +4335,9 @@ impl Tr for Copy {
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric<T: Tr>(T::Assoc);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
|
||||
// Currently rustc does not accept this.
|
||||
// #[derive(Clone)]
|
||||
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
|
||||
@ -4361,9 +4362,8 @@ fn f() {
|
||||
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric<Copy>
|
||||
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric2<Copy>
|
||||
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
|
||||
// let x = x.clone();
|
||||
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric3<Copy>
|
||||
|
Loading…
Reference in New Issue
Block a user