Merge pull request #456 from serde-rs/generic

Generate bounds on type parameters only
This commit is contained in:
David Tolnay 2016-07-22 07:58:56 -07:00 committed by GitHub
commit 85772726ee
2 changed files with 85 additions and 87 deletions

View File

@ -1,7 +1,8 @@
use std::collections::HashSet;
use aster::AstBuilder;
use syntax::ast;
use syntax::ptr::P;
use syntax::visit;
use internals::ast::Item;
@ -47,6 +48,17 @@ pub fn with_where_predicates_from_fields<F>(
.build()
}
// Puts the given bound on any generic type parameters that are used in fields
// for which filter returns true.
//
// For example, the following struct needs the bound `A: Serialize, B: Serialize`.
//
// struct S<'b, A, B: 'b, C> {
// a: A,
// b: Option<&'b B>
// #[serde(skip_serializing)]
// c: C,
// }
pub fn with_bound<F>(
builder: &AstBuilder,
item: &Item,
@ -56,95 +68,53 @@ pub fn with_bound<F>(
) -> ast::Generics
where F: Fn(&attr::Field) -> bool,
{
struct FindTyParams {
// Set of all generic type parameters on the current struct (A, B, C in
// the example). Initialized up front.
all_ty_params: HashSet<ast::Name>,
// Set of generic type parameters used in fields for which filter
// returns true (A and B in the example). Filled in as the visitor sees
// them.
relevant_ty_params: HashSet<ast::Name>,
}
impl visit::Visitor for FindTyParams {
fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) {
if !path.global && path.segments.len() == 1 {
let id = path.segments[0].identifier.name;
if self.all_ty_params.contains(&id) {
self.relevant_ty_params.insert(id);
}
}
visit::walk_path(self, path);
}
}
let all_ty_params: HashSet<_> = generics.ty_params.iter()
.map(|ty_param| ty_param.ident.name)
.collect();
let relevant_tys = item.body.all_fields()
.filter(|&field| filter(&field.attrs))
.map(|field| &field.ty);
let mut visitor = FindTyParams {
all_ty_params: all_ty_params,
relevant_ty_params: HashSet::new(),
};
for ty in relevant_tys {
visit::walk_ty(&mut visitor, ty);
}
builder.from_generics(generics.clone())
.with_predicates(
item.body.all_fields()
.filter(|&field| filter(&field.attrs))
.map(|field| &field.ty)
.filter(|ty| !contains_recursion(ty, item.ident))
.map(|ty| strip_reference(ty))
.map(|ty| builder.where_predicate()
// the type that is being bounded e.g. T
.bound().build(ty.clone())
generics.ty_params.iter()
.map(|ty_param| ty_param.ident.name)
.filter(|id| visitor.relevant_ty_params.contains(id))
.map(|id| builder.where_predicate()
// the type parameter that is being bounded e.g. T
.bound().build(builder.ty().id(id))
// the bound e.g. Serialize
.bound().trait_(bound.clone()).build()
.build()))
.build()
}
// We do not attempt to generate any bounds based on field types that are
// directly recursive, as in:
//
// struct Test<D> {
// next: Box<Test<D>>,
// }
//
// This does not catch field types that are mutually recursive with some other
// type. For those, we require bounds to be specified by a `bound` attribute if
// the inferred ones are not correct.
//
// struct Test<D> {
// #[serde(bound="D: Serialize + Deserialize")]
// next: Box<Other<D>>,
// }
// struct Other<D> {
// #[serde(bound="D: Serialize + Deserialize")]
// next: Box<Test<D>>,
// }
fn contains_recursion(ty: &ast::Ty, ident: ast::Ident) -> bool {
struct FindRecursion {
ident: ast::Ident,
found_recursion: bool,
}
impl visit::Visitor for FindRecursion {
fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) {
if !path.global
&& path.segments.len() == 1
&& path.segments[0].identifier == self.ident {
self.found_recursion = true;
} else {
visit::walk_path(self, path);
}
}
}
let mut visitor = FindRecursion {
ident: ident,
found_recursion: false,
};
visit::walk_ty(&mut visitor, ty);
visitor.found_recursion
}
// This is required to handle types that use both a reference and a value of
// the same type, as in:
//
// enum Test<'a, T> where T: 'a {
// Lifetime(&'a T),
// NoLifetime(T),
// }
//
// Preserving references, we would generate an impl like:
//
// impl<'a, T> Serialize for Test<'a, T>
// where &'a T: Serialize,
// T: Serialize { ... }
//
// And taking a reference to one of the elements would fail with:
//
// error: cannot infer an appropriate lifetime for pattern due
// to conflicting requirements [E0495]
// Test::NoLifetime(ref v) => { ... }
// ^~~~~
//
// Instead, we strip references before adding `T: Serialize` bounds in order to
// generate:
//
// impl<'a, T> Serialize for Test<'a, T>
// where T: Serialize { ... }
fn strip_reference(mut ty: &P<ast::Ty>) -> &P<ast::Ty> {
while let ast::TyKind::Rptr(_, ref mut_ty) = ty.node {
ty = &mut_ty.ty;
}
ty
}

View File

@ -99,6 +99,35 @@ struct ListNode<D> {
next: Box<ListNode<D>>,
}
#[derive(Serialize, Deserialize)]
struct RecursiveA {
b: Box<RecursiveB>,
}
#[derive(Serialize, Deserialize)]
enum RecursiveB {
A(RecursiveA),
}
#[derive(Serialize, Deserialize)]
struct RecursiveGenericA<T> {
t: T,
b: Box<RecursiveGenericB<T>>,
}
#[derive(Serialize, Deserialize)]
enum RecursiveGenericB<T> {
T(T),
A(RecursiveGenericA<T>),
}
#[derive(Serialize)]
#[allow(dead_code)]
struct OptionStatic<'a> {
a: Option<&'a str>,
b: Option<&'static str>,
}
#[derive(Serialize, Deserialize)]
#[serde(bound="D: SerializeWith + DeserializeWith")]
struct WithTraits1<D, E> {
@ -139,4 +168,3 @@ trait DeserializeWith: Sized {
struct X;
fn ser_x<S: Serializer>(_: &X, _: &mut S) -> Result<(), S::Error> { panic!() }
fn de_x<D: Deserializer>(_: &mut D) -> Result<X, D::Error> { panic!() }