Auto merge of #15000 - lowr:fix/builtin-derive-bound-for-assoc, r=HKalbasi

fix: only generate trait bound for associated types in field types

Given the following definitions:

```rust
trait Trait {
    type A;
    type B;
    type C;
}

#[derive(Clone)]
struct S<T: Trait>
where
    T::A: Send,
{
    qualified: <T as Trait>::B,
    shorthand: T::C,
}
```

we currently expand the derive macro to:

```rust
impl<T> Clone for S<T>
where
    T: Trait + Clone,
    T::A: Clone,
    T::B: Clone,
    T::C: Clone,
{ /* ... */ }
```

This does not match how rustc expands it. Specifically, `Clone` bounds for `T::A` and `T::B` should not be generated.

The criteria for associated types to get bound seem to be 1) the associated type appears as part of field types AND 2) it's written in the shorthand form. I have no idea why rustc doesn't consider qualified associated types (there's even a comment that suggests they should be considered; see rust-lang/rust#50730), but it's important to follow rustc.
This commit is contained in:
bors 2023-06-07 13:00:24 +00:00
commit 085a3112ae
3 changed files with 139 additions and 71 deletions

View File

@ -114,6 +114,66 @@ fn clone(&self ) -> Self {
);
}
#[test]
fn test_clone_expand_with_associated_types() {
check(
r#"
//- minicore: derive, clone
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
"#,
expect![[r#"
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {
qualified: qualified, shorthand: shorthand, generic: generic,
}
=>Foo {
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
}
,
}
}
}"#]],
);
}
#[test]
fn test_clone_expand_with_const_generics() {
check(

View File

@ -4,17 +4,16 @@
use base_db::{CrateOrigin, LangCrateOrigin};
use itertools::izip;
use mbe::TokenMap;
use std::collections::HashSet;
use rustc_hash::FxHashSet;
use stdx::never;
use tracing::debug;
use crate::tt::{self, TokenId};
use syntax::{
ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
HasTypeBounds, PathType,
},
match_ast,
use crate::{
name::{AsName, Name},
tt::{self, TokenId},
};
use syntax::ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
};
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
debug!("no module item parsed");
ExpandError::Other("no item found".into())
})?;
let node = item.syntax();
let (name, params, shape) = match_ast! {
match node {
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
ast::Enum(it) => {
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
debug!("expected adt, found: {:?}", item);
ExpandError::Other("expected struct, enum or union".into())
})?;
let (name, generic_param_list, shape) = match &adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
),
ast::Adt::Enum(it) => {
let default_variant = it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
(
it.name(),
it.generic_param_list(),
AdtShape::Enum {
default_variant,
variants: it.variant_list()
variants: it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
}
.map(|x| {
Ok((
name_to_token(&token_map, x.name())?,
VariantShape::from(x.field_list(), &token_map)?,
))
})
.collect::<Result<_, ExpandError>>()?,
},
)
},
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
_ => {
debug!("unexpected node is {:?}", node);
return Err(ExpandError::Other("expected struct, enum or union".into()))
},
}
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
};
let mut param_type_set: HashSet<String> = HashSet::new();
let param_types = params
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
let param_types = generic_param_list
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
let this = param.name();
match this {
Some(x) => {
param_type_set.insert(x.to_string());
param_type_set.insert(x.as_name());
mbe::syntax_node_to_token_tree(x.syntax()).0
}
None => tt::Subtree::empty(),
@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
(name, ty, bounds)
})
.collect();
let is_associated_type = |p: &PathType| {
if let Some(p) = p.path() {
if let Some(parent) = p.qualifier() {
if let Some(x) = parent.segment() {
if let Some(x) = x.path_type() {
if let Some(x) = x.path() {
if let Some(pname) = x.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// <T as Trait>::Assoc
return true;
}
}
}
}
}
if let Some(pname) = parent.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// T::Assoc
return true;
}
}
}
}
false
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
// does not do that for some unknown reason.
//
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
let field_list = match adt {
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
};
let associated_types = node
.descendants()
.filter_map(PathType::cast)
.filter(is_associated_type)
let associated_types = field_list
.into_iter()
.flat_map(|it| it.descendants())
.filter_map(ast::PathType::cast)
.filter_map(|p| {
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
param_type_set.contains(&name).then_some(p)
})
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.collect::<Vec<_>>();
.collect();
let name_token = name_to_token(&token_map, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
}
@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
/// }
/// ```
///
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(
tt: &tt::Subtree,
trait_path: tt::Subtree,
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
};
let trait_body = trait_body(&info);
let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types

View File

@ -4335,8 +4335,9 @@ impl Tr for Copy {
#[derive(Clone)]
struct AssocGeneric<T: Tr>(T::Assoc);
#[derive(Clone)]
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
// Currently rustc does not accept this.
// #[derive(Clone)]
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
#[derive(Clone)]
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
@ -4361,9 +4362,8 @@ fn f() {
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
let x = x.clone();
//^ &AssocGeneric<Copy>
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
let x = x.clone();
//^ &AssocGeneric2<Copy>
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
// let x = x.clone();
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
let x = x.clone();
//^ &AssocGeneric3<Copy>