Move associated type search into with_bounds

This commit is contained in:
David Tolnay 2018-04-12 22:39:26 -07:00
parent f06001c086
commit 24700ebeb6
No known key found for this signature in database
GPG Key ID: F9BA143B95FF6D82
3 changed files with 43 additions and 60 deletions

View File

@ -8,8 +8,9 @@
use std::collections::HashSet;
use syn::{self, visit};
use syn;
use syn::punctuated::{Punctuated, Pair};
use syn::visit::{self, Visit};
use internals::ast::{Data, Container};
use internals::attr;
@ -50,46 +51,23 @@ pub fn with_where_predicates(
generics
}
pub fn with_where_predicates_from_fields<F, W>(
pub fn with_where_predicates_from_fields<F>(
cont: &Container,
generics: &syn::Generics,
trait_bound: &syn::Path,
from_field: F,
gen_bound_where: W,
) -> syn::Generics
where
F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
W: Fn(&attr::Field) -> bool,
{
let type_params = generics.type_params()
.map(|param| param.ident)
.collect::<HashSet<_>>();
let predicates_from_associated_types = cont.data
.all_fields()
.filter(|field| gen_bound_where(&field.attrs))
.filter_map(|field| {
if let syn::Type::Path(ref ty) = *field.ty {
if let Some(Pair::Punctuated(ref t, _)) = ty.path.segments.first() {
if type_params.contains(&t.ident) {
return Some(parse_quote!(#ty: #trait_bound));
}
}
}
None::<syn::WherePredicate>
});
let predicates_from_field_attrs = cont.data
let predicates = cont.data
.all_fields()
.flat_map(|field| from_field(&field.attrs))
.flat_map(|predicates| predicates.to_vec());
let mut generics = generics.clone();
{
let predicates = &mut generics.make_where_clause().predicates;
predicates.extend(predicates_from_associated_types);
predicates.extend(predicates_from_field_attrs);
}
generics.make_where_clause()
.predicates
.extend(predicates);
generics
}
@ -113,17 +91,33 @@ pub fn with_bound<F>(
where
F: Fn(&attr::Field, Option<&attr::Variant>) -> bool,
{
struct FindTyParams {
struct FindTyParams<'ast> {
// Set of all generic type parameters on the current struct (A, B, C in
// the example). Initialized up front.
all_type_params: HashSet<syn::Ident>,
// Set of generic type parameters used in fields for which filter
// returns true (A and B in the example). Filled in as the visitor sees
// them.
relevant_type_params: HashSet<syn::Ident>,
// Fields whose type is an associated type of one of the generic type
// parameters.
associated_type_usage: Vec<&'ast syn::TypePath>,
}
impl<'ast> visit::Visit<'ast> for FindTyParams {
fn visit_path(&mut self, path: &syn::Path) {
impl<'ast> Visit<'ast> for FindTyParams<'ast> {
fn visit_field(&mut self, field: &'ast syn::Field) {
if let syn::Type::Path(ref ty) = field.ty {
if let Some(Pair::Punctuated(ref t, _)) = ty.path.segments.first() {
if self.all_type_params.contains(&t.ident) {
self.associated_type_usage.push(ty);
}
}
}
self.visit_type(&field.ty);
}
fn visit_path(&mut self, path: &'ast syn::Path) {
if let Some(seg) = path.segments.last() {
if seg.into_value().ident == "PhantomData" {
// Hardcoded exception, because PhantomData<T> implements
@ -146,7 +140,7 @@ where
// mac: T!(),
// marker: PhantomData<T>,
// }
fn visit_macro(&mut self, _mac: &syn::Macro) {}
fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
}
let all_type_params = generics.type_params()
@ -156,6 +150,7 @@ where
let mut visitor = FindTyParams {
all_type_params: all_type_params,
relevant_type_params: HashSet::new(),
associated_type_usage: Vec::new(),
};
match cont.data {
Data::Enum(ref variants) => for variant in variants.iter() {
@ -164,27 +159,28 @@ where
.iter()
.filter(|field| filter(&field.attrs, Some(&variant.attrs)));
for field in relevant_fields {
visit::visit_type(&mut visitor, field.ty);
visitor.visit_field(field.original);
}
},
Data::Struct(_, ref fields) => {
for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
visit::visit_type(&mut visitor, field.ty);
visitor.visit_field(field.original);
}
}
}
let relevant_type_params = visitor.relevant_type_params;
let associated_type_usage = visitor.associated_type_usage;
let new_predicates = generics.type_params()
.map(|param| param.ident)
.filter(|id| visitor.relevant_type_params.contains(id))
.map(|id| {
.filter(|id| relevant_type_params.contains(id))
.map(|id| syn::TypePath { qself: None, path: id.into() })
.chain(associated_type_usage.into_iter().cloned())
.map(|bounded_ty| {
syn::WherePredicate::Type(syn::PredicateType {
lifetimes: None,
// the type parameter that is being bounded e.g. T
bounded_ty: syn::Type::Path(syn::TypePath {
qself: None,
path: id.into(),
}),
bounded_ty: syn::Type::Path(bounded_ty),
colon_token: Default::default(),
// the bound e.g. Serialize
bounds: vec![

View File

@ -124,15 +124,7 @@ impl Parameters {
fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generics {
let generics = bound::without_defaults(cont.generics);
let delife = borrowed.de_lifetime();
let de_bound = parse_quote!(_serde::Deserialize<#delife>);
let generics = bound::with_where_predicates_from_fields(
cont,
&generics,
&de_bound,
attr::Field::de_bound,
|field| field.deserialize_with().is_none() && !field.skip_deserializing()
);
let generics = bound::with_where_predicates_from_fields(cont, &generics, attr::Field::de_bound);
match cont.attrs.de_bound() {
Some(predicates) => bound::with_where_predicates(&generics, predicates),
@ -144,11 +136,12 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi
attr::Default::None | attr::Default::Path(_) => generics,
};
let delife = borrowed.de_lifetime();
let generics = bound::with_bound(
cont,
&generics,
needs_deserialize_bound,
&de_bound,
&parse_quote!(_serde::Deserialize<#delife>),
);
bound::with_bound(

View File

@ -130,14 +130,8 @@ impl Parameters {
fn build_generics(cont: &Container) -> syn::Generics {
let generics = bound::without_defaults(cont.generics);
let trait_bound = parse_quote!(_serde::Serialize);
let generics = bound::with_where_predicates_from_fields(
cont,
&generics,
&trait_bound,
attr::Field::ser_bound,
|field| field.serialize_with().is_none() && !field.skip_serializing()
);
let generics =
bound::with_where_predicates_from_fields(cont, &generics, attr::Field::ser_bound);
match cont.attrs.ser_bound() {
Some(predicates) => bound::with_where_predicates(&generics, predicates),
@ -145,7 +139,7 @@ fn build_generics(cont: &Container) -> syn::Generics {
cont,
&generics,
needs_serialize_bound,
&trait_bound
&parse_quote!(_serde::Serialize),
),
}
}