Merge pull request #456 from serde-rs/generic
Generate bounds on type parameters only
This commit is contained in:
commit
85772726ee
@ -1,7 +1,8 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use aster::AstBuilder;
|
||||
|
||||
use syntax::ast;
|
||||
use syntax::ptr::P;
|
||||
use syntax::visit;
|
||||
|
||||
use internals::ast::Item;
|
||||
@ -47,6 +48,17 @@ pub fn with_where_predicates_from_fields<F>(
|
||||
.build()
|
||||
}
|
||||
|
||||
// Puts the given bound on any generic type parameters that are used in fields
|
||||
// for which filter returns true.
|
||||
//
|
||||
// For example, the following struct needs the bound `A: Serialize, B: Serialize`.
|
||||
//
|
||||
// struct S<'b, A, B: 'b, C> {
|
||||
// a: A,
|
||||
// b: Option<&'b B>
|
||||
// #[serde(skip_serializing)]
|
||||
// c: C,
|
||||
// }
|
||||
pub fn with_bound<F>(
|
||||
builder: &AstBuilder,
|
||||
item: &Item,
|
||||
@ -56,95 +68,53 @@ pub fn with_bound<F>(
|
||||
) -> ast::Generics
|
||||
where F: Fn(&attr::Field) -> bool,
|
||||
{
|
||||
struct FindTyParams {
|
||||
// Set of all generic type parameters on the current struct (A, B, C in
|
||||
// the example). Initialized up front.
|
||||
all_ty_params: HashSet<ast::Name>,
|
||||
// Set of generic type parameters used in fields for which filter
|
||||
// returns true (A and B in the example). Filled in as the visitor sees
|
||||
// them.
|
||||
relevant_ty_params: HashSet<ast::Name>,
|
||||
}
|
||||
impl visit::Visitor for FindTyParams {
|
||||
fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) {
|
||||
if !path.global && path.segments.len() == 1 {
|
||||
let id = path.segments[0].identifier.name;
|
||||
if self.all_ty_params.contains(&id) {
|
||||
self.relevant_ty_params.insert(id);
|
||||
}
|
||||
}
|
||||
visit::walk_path(self, path);
|
||||
}
|
||||
}
|
||||
|
||||
let all_ty_params: HashSet<_> = generics.ty_params.iter()
|
||||
.map(|ty_param| ty_param.ident.name)
|
||||
.collect();
|
||||
|
||||
let relevant_tys = item.body.all_fields()
|
||||
.filter(|&field| filter(&field.attrs))
|
||||
.map(|field| &field.ty);
|
||||
|
||||
let mut visitor = FindTyParams {
|
||||
all_ty_params: all_ty_params,
|
||||
relevant_ty_params: HashSet::new(),
|
||||
};
|
||||
for ty in relevant_tys {
|
||||
visit::walk_ty(&mut visitor, ty);
|
||||
}
|
||||
|
||||
builder.from_generics(generics.clone())
|
||||
.with_predicates(
|
||||
item.body.all_fields()
|
||||
.filter(|&field| filter(&field.attrs))
|
||||
.map(|field| &field.ty)
|
||||
.filter(|ty| !contains_recursion(ty, item.ident))
|
||||
.map(|ty| strip_reference(ty))
|
||||
.map(|ty| builder.where_predicate()
|
||||
// the type that is being bounded e.g. T
|
||||
.bound().build(ty.clone())
|
||||
generics.ty_params.iter()
|
||||
.map(|ty_param| ty_param.ident.name)
|
||||
.filter(|id| visitor.relevant_ty_params.contains(id))
|
||||
.map(|id| builder.where_predicate()
|
||||
// the type parameter that is being bounded e.g. T
|
||||
.bound().build(builder.ty().id(id))
|
||||
// the bound e.g. Serialize
|
||||
.bound().trait_(bound.clone()).build()
|
||||
.build()))
|
||||
.build()
|
||||
}
|
||||
|
||||
// We do not attempt to generate any bounds based on field types that are
|
||||
// directly recursive, as in:
|
||||
//
|
||||
// struct Test<D> {
|
||||
// next: Box<Test<D>>,
|
||||
// }
|
||||
//
|
||||
// This does not catch field types that are mutually recursive with some other
|
||||
// type. For those, we require bounds to be specified by a `bound` attribute if
|
||||
// the inferred ones are not correct.
|
||||
//
|
||||
// struct Test<D> {
|
||||
// #[serde(bound="D: Serialize + Deserialize")]
|
||||
// next: Box<Other<D>>,
|
||||
// }
|
||||
// struct Other<D> {
|
||||
// #[serde(bound="D: Serialize + Deserialize")]
|
||||
// next: Box<Test<D>>,
|
||||
// }
|
||||
fn contains_recursion(ty: &ast::Ty, ident: ast::Ident) -> bool {
|
||||
struct FindRecursion {
|
||||
ident: ast::Ident,
|
||||
found_recursion: bool,
|
||||
}
|
||||
impl visit::Visitor for FindRecursion {
|
||||
fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) {
|
||||
if !path.global
|
||||
&& path.segments.len() == 1
|
||||
&& path.segments[0].identifier == self.ident {
|
||||
self.found_recursion = true;
|
||||
} else {
|
||||
visit::walk_path(self, path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut visitor = FindRecursion {
|
||||
ident: ident,
|
||||
found_recursion: false,
|
||||
};
|
||||
visit::walk_ty(&mut visitor, ty);
|
||||
visitor.found_recursion
|
||||
}
|
||||
|
||||
// This is required to handle types that use both a reference and a value of
|
||||
// the same type, as in:
|
||||
//
|
||||
// enum Test<'a, T> where T: 'a {
|
||||
// Lifetime(&'a T),
|
||||
// NoLifetime(T),
|
||||
// }
|
||||
//
|
||||
// Preserving references, we would generate an impl like:
|
||||
//
|
||||
// impl<'a, T> Serialize for Test<'a, T>
|
||||
// where &'a T: Serialize,
|
||||
// T: Serialize { ... }
|
||||
//
|
||||
// And taking a reference to one of the elements would fail with:
|
||||
//
|
||||
// error: cannot infer an appropriate lifetime for pattern due
|
||||
// to conflicting requirements [E0495]
|
||||
// Test::NoLifetime(ref v) => { ... }
|
||||
// ^~~~~
|
||||
//
|
||||
// Instead, we strip references before adding `T: Serialize` bounds in order to
|
||||
// generate:
|
||||
//
|
||||
// impl<'a, T> Serialize for Test<'a, T>
|
||||
// where T: Serialize { ... }
|
||||
fn strip_reference(mut ty: &P<ast::Ty>) -> &P<ast::Ty> {
|
||||
while let ast::TyKind::Rptr(_, ref mut_ty) = ty.node {
|
||||
ty = &mut_ty.ty;
|
||||
}
|
||||
ty
|
||||
}
|
||||
|
@ -99,6 +99,35 @@ struct ListNode<D> {
|
||||
next: Box<ListNode<D>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RecursiveA {
|
||||
b: Box<RecursiveB>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
enum RecursiveB {
|
||||
A(RecursiveA),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RecursiveGenericA<T> {
|
||||
t: T,
|
||||
b: Box<RecursiveGenericB<T>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
enum RecursiveGenericB<T> {
|
||||
T(T),
|
||||
A(RecursiveGenericA<T>),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[allow(dead_code)]
|
||||
struct OptionStatic<'a> {
|
||||
a: Option<&'a str>,
|
||||
b: Option<&'static str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(bound="D: SerializeWith + DeserializeWith")]
|
||||
struct WithTraits1<D, E> {
|
||||
@ -139,4 +168,3 @@ trait DeserializeWith: Sized {
|
||||
struct X;
|
||||
fn ser_x<S: Serializer>(_: &X, _: &mut S) -> Result<(), S::Error> { panic!() }
|
||||
fn de_x<D: Deserializer>(_: &mut D) -> Result<X, D::Error> { panic!() }
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user