fix: preserve where clause when builtin derive

This commit is contained in:
austaras 2024-02-04 11:35:27 +08:00
parent e9d3565cd1
commit dad0fdb13f
3 changed files with 50 additions and 5 deletions

View File

@ -157,7 +157,7 @@ where
generic: Vec<T::InGenericArg>,
}
impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where <T as Trait>::InWc: Marker, T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {

View File

@ -194,6 +194,7 @@ struct BasicAdtInfo {
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
where_clause: Vec<tt::Subtree>,
associated_types: Vec<tt::Subtree>,
}
@ -202,10 +203,11 @@ fn parse_adt(
adt: &ast::Adt,
call_site: Span,
) -> Result<BasicAdtInfo, ExpandError> {
let (name, generic_param_list, shape) = match adt {
let (name, generic_param_list, where_clause, shape) = match adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
it.where_clause(),
AdtShape::Struct(VariantShape::from(tm, it.field_list())?),
),
ast::Adt::Enum(it) => {
@ -217,6 +219,7 @@ fn parse_adt(
(
it.name(),
it.generic_param_list(),
it.where_clause(),
AdtShape::Enum {
default_variant,
variants: it
@ -233,7 +236,9 @@ fn parse_adt(
},
)
}
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
ast::Adt::Union(it) => {
(it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union)
}
};
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
@ -274,6 +279,14 @@ fn parse_adt(
})
.collect();
let where_clause = if let Some(w) = where_clause {
w.predicates()
.map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
.collect()
} else {
vec![]
};
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
@ -301,7 +314,7 @@ fn parse_adt(
.map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
.collect();
let name_token = name_to_token(tm, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
}
fn name_to_token(
@ -366,7 +379,8 @@ fn expand_simple_derive(
}
};
let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let mut where_block: Vec<_> =
info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
.into_iter()

View File

@ -1373,3 +1373,34 @@ pub fn attr_macro() {}
"#,
);
}
#[test]
fn clone_with_type_bound() {
check_types(
r#"
//- minicore: derive, clone, builtin_impls
#[derive(Clone)]
struct Float;
trait TensorKind: Clone {
/// The primitive type of the tensor.
type Primitive: Clone;
}
impl TensorKind for Float {
type Primitive = f64;
}
#[derive(Clone)]
struct Tensor<K = Float> where K: TensorKind
{
primitive: K::Primitive,
}
fn foo(t: Tensor) {
let x = t.clone();
//^ Tensor<Float>
}
"#,
);
}