diff --git a/serde_codegen/src/bound.rs b/serde_codegen/src/bound.rs index cdbe5fa1..2c659e33 100644 --- a/serde_codegen/src/bound.rs +++ b/serde_codegen/src/bound.rs @@ -70,6 +70,7 @@ pub fn with_bound( .map(|&(ref field, _)| &field.ty) // TODO this filter can be removed later, see comment on function .filter(|ty| contains_generic(ty, generics)) + .filter(|ty| !contains_recursion(ty, item.ident)) .map(|ty| strip_reference(ty)) .map(|ty| builder.where_predicate() // the type that is being bounded e.g. T @@ -159,6 +160,50 @@ fn contains_generic(ty: &ast::Ty, generics: &ast::Generics) -> bool { visitor.found_generic } +// We do not attempt to generate any bounds based on field types that are +// directly recursive, as in: +// +// struct Test { +// next: Box>, +// } +// +// This does not catch field types that are mutually recursive with some other +// type. For those, we require bounds to be specified by a `where` attribute if +// the inferred ones are not correct. +// +// struct Test { +// #[serde(where="D: Serialize + Deserialize")] +// next: Box>, +// } +// struct Other { +// #[serde(where="D: Serialize + Deserialize")] +// next: Box>, +// } +fn contains_recursion(ty: &ast::Ty, ident: ast::Ident) -> bool { + struct FindRecursion { + ident: ast::Ident, + found_recursion: bool, + } + impl<'v> visit::Visitor<'v> for FindRecursion { + fn visit_path(&mut self, path: &'v ast::Path, _id: ast::NodeId) { + if !path.global + && path.segments.len() == 1 + && path.segments[0].identifier == self.ident { + self.found_recursion = true; + } else { + visit::walk_path(self, path); + } + } + } + + let mut visitor = FindRecursion { + ident: ident, + found_recursion: false, + }; + visit::walk_ty(&mut visitor, ty); + visitor.found_recursion +} + // This is required to handle types that use both a reference and a value of // the same type, as in: // diff --git a/serde_tests/tests/test_gen.rs b/serde_tests/tests/test_gen.rs index e6cbeff6..6559b9b9 100644 --- a/serde_tests/tests/test_gen.rs +++ b/serde_tests/tests/test_gen.rs @@ -75,7 +75,6 @@ struct Tuple( ); #[derive(Serialize, Deserialize)] -#[serde(where(serialize="D: Serialize", deserialize="D: Deserialize"))] enum TreeNode { Split { left: Box>, @@ -89,17 +88,33 @@ enum TreeNode { #[derive(Serialize, Deserialize)] struct ListNode { data: D, - #[serde(where="")] next: Box>, } #[derive(Serialize, Deserialize)] -struct SerializeWithTrait { +#[serde(where="D: SerializeWith + DeserializeWith")] +struct WithTraits1 { + #[serde(serialize_with="SerializeWith::serialize_with", + deserialize_with="DeserializeWith::deserialize_with")] + d: D, #[serde(serialize_with="SerializeWith::serialize_with", deserialize_with="DeserializeWith::deserialize_with", - where(serialize="D: SerializeWith", - deserialize="D: DeserializeWith"))] - data: D, + where="E: SerializeWith + DeserializeWith")] + e: E, +} + +#[derive(Serialize, Deserialize)] +#[serde(where(serialize="D: SerializeWith", + deserialize="D: DeserializeWith"))] +struct WithTraits2 { + #[serde(serialize_with="SerializeWith::serialize_with", + deserialize_with="DeserializeWith::deserialize_with")] + d: D, + #[serde(serialize_with="SerializeWith::serialize_with", + deserialize_with="DeserializeWith::deserialize_with", + where(serialize="E: SerializeWith", + deserialize="E: DeserializeWith"))] + e: E, } //////////////////////////////////////////////////////////////////////////