Add bounds for fields in derive macro
This commit is contained in:
parent
e739c999cd
commit
0241b52dad
@ -16,7 +16,7 @@ fn test_copy_expand_simple() {
|
||||
#[derive(Copy)]
|
||||
struct Foo;
|
||||
|
||||
impl < > core::marker::Copy for Foo< > {}"#]],
|
||||
impl < > core::marker::Copy for Foo< > where {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ fn test_copy_expand_in_core() {
|
||||
#[derive(Copy)]
|
||||
struct Foo;
|
||||
|
||||
impl < > crate ::marker::Copy for Foo< > {}"#]],
|
||||
impl < > crate ::marker::Copy for Foo< > where {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ fn test_copy_expand_with_type_params() {
|
||||
#[derive(Copy)]
|
||||
struct Foo<A, B>;
|
||||
|
||||
impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
|
||||
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ fn test_copy_expand_with_lifetimes() {
|
||||
#[derive(Copy)]
|
||||
struct Foo<A, B, 'a, 'b>;
|
||||
|
||||
impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
|
||||
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
@ -90,7 +90,7 @@ fn test_clone_expand() {
|
||||
#[derive(Clone)]
|
||||
struct Foo<A, B>;
|
||||
|
||||
impl <T0: core::clone::Clone, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
|
||||
impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Foo<A, B, > where {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
@ -106,6 +106,6 @@ fn test_clone_expand_with_const_generics() {
|
||||
#[derive(Clone)]
|
||||
struct Foo<const X: usize, T>(u32);
|
||||
|
||||
impl <const T0: usize, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
|
||||
impl <const X: usize, T: core::clone::Clone, > core::clone::Clone for Foo<X, T, > where u32: core::clone::Clone, {}"#]],
|
||||
);
|
||||
}
|
||||
|
@ -1,11 +1,12 @@
|
||||
//! Builtin derives.
|
||||
|
||||
use base_db::{CrateOrigin, LangCrateOrigin};
|
||||
use either::Either;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::tt::{self, TokenId};
|
||||
use syntax::{
|
||||
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName},
|
||||
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds},
|
||||
match_ast,
|
||||
};
|
||||
|
||||
@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>
|
||||
|
||||
struct BasicAdtInfo {
|
||||
name: tt::Ident,
|
||||
/// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
|
||||
param_types: Vec<Option<tt::Subtree>>,
|
||||
/// first field is the name, and
|
||||
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
|
||||
/// third fields is where bounds, if any
|
||||
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
|
||||
field_types: Vec<tt::Subtree>,
|
||||
}
|
||||
|
||||
fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
@ -75,17 +79,34 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
ExpandError::Other("no item found".into())
|
||||
})?;
|
||||
let node = item.syntax();
|
||||
let (name, params) = match_ast! {
|
||||
let (name, params, fields) = match_ast! {
|
||||
match node {
|
||||
ast::Struct(it) => (it.name(), it.generic_param_list()),
|
||||
ast::Enum(it) => (it.name(), it.generic_param_list()),
|
||||
ast::Union(it) => (it.name(), it.generic_param_list()),
|
||||
ast::Struct(it) => {
|
||||
(it.name(), it.generic_param_list(), it.field_list().into_iter().collect::<Vec<_>>())
|
||||
},
|
||||
ast::Enum(it) => (it.name(), it.generic_param_list(), it.variant_list().into_iter().flat_map(|x| x.variants()).filter_map(|x| x.field_list()).collect()),
|
||||
ast::Union(it) => (it.name(), it.generic_param_list(), it.record_field_list().into_iter().map(|x| ast::FieldList::RecordFieldList(x)).collect()),
|
||||
_ => {
|
||||
debug!("unexpected node is {:?}", node);
|
||||
return Err(ExpandError::Other("expected struct, enum or union".into()))
|
||||
},
|
||||
}
|
||||
};
|
||||
let field_types = fields
|
||||
.into_iter()
|
||||
.flat_map(|f| match f {
|
||||
ast::FieldList::RecordFieldList(x) => Either::Left(
|
||||
x.fields()
|
||||
.filter_map(|x| x.ty())
|
||||
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0),
|
||||
),
|
||||
ast::FieldList::TupleFieldList(x) => Either::Right(
|
||||
x.fields()
|
||||
.filter_map(|x| x.ty())
|
||||
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0),
|
||||
),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let name = name.ok_or_else(|| {
|
||||
debug!("parsed item has no name");
|
||||
ExpandError::Other("missing name".into())
|
||||
@ -97,7 +118,17 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
.into_iter()
|
||||
.flat_map(|param_list| param_list.type_or_const_params())
|
||||
.map(|param| {
|
||||
if let ast::TypeOrConstParam::Const(param) = param {
|
||||
let name = param
|
||||
.name()
|
||||
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
|
||||
.unwrap_or_else(tt::Subtree::empty);
|
||||
let bounds = match ¶m {
|
||||
ast::TypeOrConstParam::Type(x) => {
|
||||
x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
|
||||
}
|
||||
ast::TypeOrConstParam::Const(_) => None,
|
||||
};
|
||||
let ty = if let ast::TypeOrConstParam::Const(param) = param {
|
||||
let ty = param
|
||||
.ty()
|
||||
.map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0)
|
||||
@ -105,10 +136,11 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
(name, ty, bounds)
|
||||
})
|
||||
.collect();
|
||||
Ok(BasicAdtInfo { name: name_token, param_types })
|
||||
Ok(BasicAdtInfo { name: name_token, param_types, field_types })
|
||||
}
|
||||
|
||||
fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> {
|
||||
@ -116,16 +148,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
|
||||
Ok(info) => info,
|
||||
Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e),
|
||||
};
|
||||
let mut where_block = vec![];
|
||||
let (params, args): (Vec<_>, Vec<_>) = info
|
||||
.param_types
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, param_ty)| {
|
||||
let ident = tt::Leaf::Ident(tt::Ident {
|
||||
span: tt::TokenId::unspecified(),
|
||||
text: format!("T{idx}").into(),
|
||||
});
|
||||
.map(|(ident, param_ty, bound)| {
|
||||
let ident_ = ident.clone();
|
||||
if let Some(b) = bound {
|
||||
let ident = ident.clone();
|
||||
where_block.push(quote! { #ident : #b , });
|
||||
}
|
||||
if let Some(ty) = param_ty {
|
||||
(quote! { const #ident : #ty , }, quote! { #ident_ , })
|
||||
} else {
|
||||
@ -134,9 +166,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
|
||||
}
|
||||
})
|
||||
.unzip();
|
||||
|
||||
where_block.extend(info.field_types.iter().map(|x| {
|
||||
let x = x.clone();
|
||||
let bound = trait_path.clone();
|
||||
quote! { #x : #bound , }
|
||||
}));
|
||||
|
||||
let name = info.name;
|
||||
let expanded = quote! {
|
||||
impl < ##params > #trait_path for #name < ##args > {}
|
||||
impl < ##params > #trait_path for #name < ##args > where ##where_block {}
|
||||
};
|
||||
ExpandResult::ok(expanded)
|
||||
}
|
||||
|
@ -471,7 +471,7 @@ struct Foo {}
|
||||
"#,
|
||||
expect![[r#"
|
||||
Clone
|
||||
impl < >core::clone::Clone for Foo< >{}
|
||||
impl < >core::clone::Clone for Foo< >where{}
|
||||
"#]],
|
||||
);
|
||||
}
|
||||
@ -488,7 +488,7 @@ struct Foo {}
|
||||
"#,
|
||||
expect![[r#"
|
||||
Copy
|
||||
impl < >core::marker::Copy for Foo< >{}
|
||||
impl < >core::marker::Copy for Foo< >where{}
|
||||
"#]],
|
||||
);
|
||||
}
|
||||
@ -504,7 +504,7 @@ struct Foo {}
|
||||
"#,
|
||||
expect![[r#"
|
||||
Copy
|
||||
impl < >core::marker::Copy for Foo< >{}
|
||||
impl < >core::marker::Copy for Foo< >where{}
|
||||
"#]],
|
||||
);
|
||||
check(
|
||||
@ -516,7 +516,7 @@ struct Foo {}
|
||||
"#,
|
||||
expect![[r#"
|
||||
Clone
|
||||
impl < >core::clone::Clone for Foo< >{}
|
||||
impl < >core::clone::Clone for Foo< >where{}
|
||||
"#]],
|
||||
);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user