Deriving: Include bound generic params for extracted type parameters in where clause

This commit is contained in:
Audun Halland 2021-09-29 00:46:29 +02:00
parent 8f8092cc32
commit f0e99827f8
2 changed files with 54 additions and 10 deletions

View File

@ -332,20 +332,27 @@ pub fn combine_substructure(
RefCell::new(f)
}
struct TypeParameter {
bound_generic_params: Vec<ast::GenericParam>,
ty: P<ast::Ty>,
}
/// This method helps to extract all the type parameters referenced from a
/// type. For a type parameter `<T>`, it looks for either a `TyPath` that
/// is not global and starts with `T`, or a `TyQPath`.
/// Also include bound generic params from the input type.
fn find_type_parameters(
ty: &ast::Ty,
ty_param_names: &[Symbol],
cx: &ExtCtxt<'_>,
) -> Vec<P<ast::Ty>> {
) -> Vec<TypeParameter> {
use rustc_ast::visit;
struct Visitor<'a, 'b> {
cx: &'a ExtCtxt<'b>,
ty_param_names: &'a [Symbol],
types: Vec<P<ast::Ty>>,
bound_generic_params_stack: Vec<ast::GenericParam>,
type_params: Vec<TypeParameter>,
}
impl<'a, 'b> visit::Visitor<'a> for Visitor<'a, 'b> {
@ -353,7 +360,10 @@ fn find_type_parameters(
if let ast::TyKind::Path(_, ref path) = ty.kind {
if let Some(segment) = path.segments.first() {
if self.ty_param_names.contains(&segment.ident.name) {
self.types.push(P(ty.clone()));
self.type_params.push(TypeParameter {
bound_generic_params: self.bound_generic_params_stack.clone(),
ty: P(ty.clone()),
});
}
}
}
@ -361,15 +371,35 @@ fn find_type_parameters(
visit::walk_ty(self, ty)
}
// Place bound generic params on a stack, to extract them when a type is encountered.
fn visit_poly_trait_ref(
&mut self,
trait_ref: &'a ast::PolyTraitRef,
modifier: &'a ast::TraitBoundModifier,
) {
let stack_len = trait_ref.bound_generic_params.len();
self.bound_generic_params_stack
.extend(trait_ref.bound_generic_params.clone().into_iter());
visit::walk_poly_trait_ref(self, trait_ref, modifier);
self.bound_generic_params_stack.truncate(stack_len);
}
fn visit_mac_call(&mut self, mac: &ast::MacCall) {
self.cx.span_err(mac.span(), "`derive` cannot be used on items with type macros");
}
}
let mut visitor = Visitor { cx, ty_param_names, types: Vec::new() };
let mut visitor = Visitor {
cx,
ty_param_names,
bound_generic_params_stack: Vec::new(),
type_params: Vec::new(),
};
visit::Visitor::visit_ty(&mut visitor, ty);
visitor.types
visitor.type_params
}
impl<'a> TraitDef<'a> {
@ -617,11 +647,11 @@ impl<'a> TraitDef<'a> {
ty_params.map(|ty_param| ty_param.ident.name).collect();
for field_ty in field_tys {
let tys = find_type_parameters(&field_ty, &ty_param_names, cx);
let field_ty_params = find_type_parameters(&field_ty, &ty_param_names, cx);
for ty in tys {
for field_ty_param in field_ty_params {
// if we have already handled this type, skip it
if let ast::TyKind::Path(_, ref p) = ty.kind {
if let ast::TyKind::Path(_, ref p) = field_ty_param.ty.kind {
if p.segments.len() == 1
&& ty_param_names.contains(&p.segments[0].ident.name)
{
@ -639,8 +669,8 @@ impl<'a> TraitDef<'a> {
let predicate = ast::WhereBoundPredicate {
span: self.span,
bound_generic_params: Vec::new(),
bounded_ty: ty,
bound_generic_params: field_ty_param.bound_generic_params,
bounded_ty: field_ty_param.ty,
bounds,
};

View File

@ -0,0 +1,14 @@
// check-pass
#![feature(generic_associated_types)]
trait CallWithShim: Sized {
type Shim<'s>
where
Self: 's;
}
#[derive(Clone)]
struct ShimMethod<T: CallWithShim + 'static>(pub &'static dyn for<'s> Fn(&'s mut T::Shim<'s>));
pub fn main() {}