fix: only generate trait bound for associated types in field types

This commit is contained in:
Ryo Yoshida 2023-06-07 18:37:09 +09:00
parent 1c25885bd2
commit 4f0c6fac17
No known key found for this signature in database
GPG Key ID: E25698A930586171
3 changed files with 139 additions and 71 deletions

View File

@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
);
}
#[test]
fn test_clone_expand_with_associated_types() {
check(
r#"
//- minicore: derive, clone
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
"#,
expect![[r#"
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {
qualified: qualified, shorthand: shorthand, generic: generic,
}
=>Foo {
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
}
,
}
}
}"#]],
);
}
#[test]
fn test_clone_expand_with_const_generics() {
check(

View File

@ -4,17 +4,16 @@ use ::tt::Ident;
use base_db::{CrateOrigin, LangCrateOrigin};
use itertools::izip;
use mbe::TokenMap;
use std::collections::HashSet;
use rustc_hash::FxHashSet;
use stdx::never;
use tracing::debug;
use crate::tt::{self, TokenId};
use syntax::{
ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
HasTypeBounds, PathType,
},
match_ast,
use crate::{
name::{AsName, Name},
tt::{self, TokenId},
};
use syntax::ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
};
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
debug!("no module item parsed");
ExpandError::Other("no item found".into())
})?;
let node = item.syntax();
let (name, params, shape) = match_ast! {
match node {
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
ast::Enum(it) => {
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
(
it.name(),
it.generic_param_list(),
AdtShape::Enum {
default_variant,
variants: it.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
}
)
},
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
_ => {
debug!("unexpected node is {:?}", node);
return Err(ExpandError::Other("expected struct, enum or union".into()))
},
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
debug!("expected adt, found: {:?}", item);
ExpandError::Other("expected struct, enum or union".into())
})?;
let (name, generic_param_list, shape) = match &adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
),
ast::Adt::Enum(it) => {
let default_variant = it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
(
it.name(),
it.generic_param_list(),
AdtShape::Enum {
default_variant,
variants: it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.map(|x| {
Ok((
name_to_token(&token_map, x.name())?,
VariantShape::from(x.field_list(), &token_map)?,
))
})
.collect::<Result<_, ExpandError>>()?,
},
)
}
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
};
let mut param_type_set: HashSet<String> = HashSet::new();
let param_types = params
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
let param_types = generic_param_list
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
let this = param.name();
match this {
Some(x) => {
param_type_set.insert(x.to_string());
param_type_set.insert(x.as_name());
mbe::syntax_node_to_token_tree(x.syntax()).0
}
None => tt::Subtree::empty(),
@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
(name, ty, bounds)
})
.collect();
let is_associated_type = |p: &PathType| {
if let Some(p) = p.path() {
if let Some(parent) = p.qualifier() {
if let Some(x) = parent.segment() {
if let Some(x) = x.path_type() {
if let Some(x) = x.path() {
if let Some(pname) = x.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// <T as Trait>::Assoc
return true;
}
}
}
}
}
if let Some(pname) = parent.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// T::Assoc
return true;
}
}
}
}
false
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
// does not do that for some unknown reason.
//
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
let field_list = match adt {
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
};
let associated_types = node
.descendants()
.filter_map(PathType::cast)
.filter(is_associated_type)
let associated_types = field_list
.into_iter()
.flat_map(|it| it.descendants())
.filter_map(ast::PathType::cast)
.filter_map(|p| {
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
param_type_set.contains(&name).then_some(p)
})
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.collect::<Vec<_>>();
.collect();
let name_token = name_to_token(&token_map, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
}
@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
/// }
/// ```
///
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(
tt: &tt::Subtree,
trait_path: tt::Subtree,
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
};
let trait_body = trait_body(&info);
let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types

View File

@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
#[derive(Clone)]
struct AssocGeneric<T: Tr>(T::Assoc);
#[derive(Clone)]
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
// Currently rustc does not accept this.
// #[derive(Clone)]
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
#[derive(Clone)]
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
let x = x.clone();
//^ &AssocGeneric<Copy>
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
let x = x.clone();
//^ &AssocGeneric2<Copy>
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
// let x = x.clone();
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
let x = x.clone();
//^ &AssocGeneric3<Copy>