fix: only generate trait bound for associated types in field types
This commit is contained in:
parent
1c25885bd2
commit
4f0c6fac17
@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone_expand_with_associated_types() {
|
||||
check(
|
||||
r#"
|
||||
//- minicore: derive, clone
|
||||
trait Trait {
|
||||
type InWc;
|
||||
type InFieldQualified;
|
||||
type InFieldShorthand;
|
||||
type InGenericArg;
|
||||
}
|
||||
trait Marker {}
|
||||
struct Vec<T>(T);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Foo<T: Trait>
|
||||
where
|
||||
<T as Trait>::InWc: Marker,
|
||||
{
|
||||
qualified: <T as Trait>::InFieldQualified,
|
||||
shorthand: T::InFieldShorthand,
|
||||
generic: Vec<T::InGenericArg>,
|
||||
}
|
||||
"#,
|
||||
expect![[r#"
|
||||
trait Trait {
|
||||
type InWc;
|
||||
type InFieldQualified;
|
||||
type InFieldShorthand;
|
||||
type InGenericArg;
|
||||
}
|
||||
trait Marker {}
|
||||
struct Vec<T>(T);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Foo<T: Trait>
|
||||
where
|
||||
<T as Trait>::InWc: Marker,
|
||||
{
|
||||
qualified: <T as Trait>::InFieldQualified,
|
||||
shorthand: T::InFieldShorthand,
|
||||
generic: Vec<T::InGenericArg>,
|
||||
}
|
||||
|
||||
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
|
||||
fn clone(&self ) -> Self {
|
||||
match self {
|
||||
Foo {
|
||||
qualified: qualified, shorthand: shorthand, generic: generic,
|
||||
}
|
||||
=>Foo {
|
||||
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
|
||||
}
|
||||
,
|
||||
}
|
||||
}
|
||||
}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone_expand_with_const_generics() {
|
||||
check(
|
||||
|
@ -4,17 +4,16 @@ use ::tt::Ident;
|
||||
use base_db::{CrateOrigin, LangCrateOrigin};
|
||||
use itertools::izip;
|
||||
use mbe::TokenMap;
|
||||
use std::collections::HashSet;
|
||||
use rustc_hash::FxHashSet;
|
||||
use stdx::never;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::tt::{self, TokenId};
|
||||
use syntax::{
|
||||
ast::{
|
||||
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
|
||||
HasTypeBounds, PathType,
|
||||
},
|
||||
match_ast,
|
||||
use crate::{
|
||||
name::{AsName, Name},
|
||||
tt::{self, TokenId},
|
||||
};
|
||||
use syntax::ast::{
|
||||
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
|
||||
};
|
||||
|
||||
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
|
||||
@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
debug!("no module item parsed");
|
||||
ExpandError::Other("no item found".into())
|
||||
})?;
|
||||
let node = item.syntax();
|
||||
let (name, params, shape) = match_ast! {
|
||||
match node {
|
||||
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
|
||||
ast::Enum(it) => {
|
||||
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
|
||||
(
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
AdtShape::Enum {
|
||||
default_variant,
|
||||
variants: it.variant_list()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.variants())
|
||||
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
|
||||
}
|
||||
)
|
||||
},
|
||||
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
|
||||
_ => {
|
||||
debug!("unexpected node is {:?}", node);
|
||||
return Err(ExpandError::Other("expected struct, enum or union".into()))
|
||||
},
|
||||
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
|
||||
debug!("expected adt, found: {:?}", item);
|
||||
ExpandError::Other("expected struct, enum or union".into())
|
||||
})?;
|
||||
let (name, generic_param_list, shape) = match &adt {
|
||||
ast::Adt::Struct(it) => (
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
|
||||
),
|
||||
ast::Adt::Enum(it) => {
|
||||
let default_variant = it
|
||||
.variant_list()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.variants())
|
||||
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
|
||||
(
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
AdtShape::Enum {
|
||||
default_variant,
|
||||
variants: it
|
||||
.variant_list()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.variants())
|
||||
.map(|x| {
|
||||
Ok((
|
||||
name_to_token(&token_map, x.name())?,
|
||||
VariantShape::from(x.field_list(), &token_map)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<_, ExpandError>>()?,
|
||||
},
|
||||
)
|
||||
}
|
||||
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
|
||||
};
|
||||
let mut param_type_set: HashSet<String> = HashSet::new();
|
||||
let param_types = params
|
||||
|
||||
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
|
||||
let param_types = generic_param_list
|
||||
.into_iter()
|
||||
.flat_map(|param_list| param_list.type_or_const_params())
|
||||
.map(|param| {
|
||||
@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
let this = param.name();
|
||||
match this {
|
||||
Some(x) => {
|
||||
param_type_set.insert(x.to_string());
|
||||
param_type_set.insert(x.as_name());
|
||||
mbe::syntax_node_to_token_tree(x.syntax()).0
|
||||
}
|
||||
None => tt::Subtree::empty(),
|
||||
@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
(name, ty, bounds)
|
||||
})
|
||||
.collect();
|
||||
let is_associated_type = |p: &PathType| {
|
||||
if let Some(p) = p.path() {
|
||||
if let Some(parent) = p.qualifier() {
|
||||
if let Some(x) = parent.segment() {
|
||||
if let Some(x) = x.path_type() {
|
||||
if let Some(x) = x.path() {
|
||||
if let Some(pname) = x.as_single_name_ref() {
|
||||
if param_type_set.contains(&pname.to_string()) {
|
||||
// <T as Trait>::Assoc
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(pname) = parent.as_single_name_ref() {
|
||||
if param_type_set.contains(&pname.to_string()) {
|
||||
// T::Assoc
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
|
||||
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
|
||||
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
|
||||
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
|
||||
// does not do that for some unknown reason.
|
||||
//
|
||||
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
|
||||
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
|
||||
|
||||
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
|
||||
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
|
||||
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
|
||||
let field_list = match adt {
|
||||
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
|
||||
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
|
||||
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
|
||||
};
|
||||
let associated_types = node
|
||||
.descendants()
|
||||
.filter_map(PathType::cast)
|
||||
.filter(is_associated_type)
|
||||
let associated_types = field_list
|
||||
.into_iter()
|
||||
.flat_map(|it| it.descendants())
|
||||
.filter_map(ast::PathType::cast)
|
||||
.filter_map(|p| {
|
||||
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
|
||||
param_type_set.contains(&name).then_some(p)
|
||||
})
|
||||
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
|
||||
.collect::<Vec<_>>();
|
||||
.collect();
|
||||
let name_token = name_to_token(&token_map, name)?;
|
||||
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
|
||||
}
|
||||
@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
|
||||
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
|
||||
/// therefore does not get bound by the derived trait.
|
||||
fn expand_simple_derive(
|
||||
tt: &tt::Subtree,
|
||||
trait_path: tt::Subtree,
|
||||
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
|
||||
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
|
||||
) -> ExpandResult<tt::Subtree> {
|
||||
let info = match parse_adt(tt) {
|
||||
Ok(info) => info,
|
||||
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
|
||||
};
|
||||
let trait_body = trait_body(&info);
|
||||
let trait_body = make_trait_body(&info);
|
||||
let mut where_block = vec![];
|
||||
let (params, args): (Vec<_>, Vec<_>) = info
|
||||
.param_types
|
||||
|
@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric<T: Tr>(T::Assoc);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
|
||||
// Currently rustc does not accept this.
|
||||
// #[derive(Clone)]
|
||||
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
|
||||
@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
|
||||
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric<Copy>
|
||||
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric2<Copy>
|
||||
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
|
||||
// let x = x.clone();
|
||||
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
|
||||
let x = x.clone();
|
||||
//^ &AssocGeneric3<Copy>
|
||||
|
Loading…
x
Reference in New Issue
Block a user