From 19ec8bbdb9e2cdaa2225bcb0a86f2f8b9677cb2b Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sun, 28 Feb 2016 12:17:29 -0800 Subject: [PATCH] feat(codegen): Inhibit generic bounds if skip_serializing The generated code for a struct like: struct Test { a: X #[serde(skip_serializing)] b: B #[serde(serialize_with="...")] c: C } Used to be: impl Serialize for Test where A: Serialize, B: Serialize, C: Serialize, { ... } Now it is: impl Serialize for Test where X: Serialize, { ... } Both `skip_serializing` and `serialize_with` mean the type does not need to implement `Serialize`. --- serde_codegen/src/attr.rs | 2 +- serde_codegen/src/ser.rs | 117 ++++++++++++++++++++++++-- serde_tests/tests/test_annotations.rs | 100 ++++++++++++++++++---- 3 files changed, 195 insertions(+), 24 deletions(-) diff --git a/serde_codegen/src/attr.rs b/serde_codegen/src/attr.rs index d937f605..d8ec35c5 100644 --- a/serde_codegen/src/attr.rs +++ b/serde_codegen/src/attr.rs @@ -382,7 +382,7 @@ fn get_renames(cx: &ExtCtxt, Ok((ser_name, de_name)) } -fn get_serde_meta_items(attr: &ast::Attribute) -> Option<&[P]> { +pub fn get_serde_meta_items(attr: &ast::Attribute) -> Option<&[P]> { match attr.node.value.node { ast::MetaItemKind::List(ref name, ref items) if name == &"serde" => { attr::mark_used(&attr); diff --git a/serde_codegen/src/ser.rs b/serde_codegen/src/ser.rs index cb97640d..1c87a9cf 100644 --- a/serde_codegen/src/ser.rs +++ b/serde_codegen/src/ser.rs @@ -60,11 +60,7 @@ fn serialize_item( } }; - let impl_generics = builder.from_generics(generics.clone()) - .add_ty_param_bound( - builder.path().global().ids(&["serde", "ser", "Serialize"]).build() - ) - .build(); + let impl_generics = build_impl_generics(cx, builder, item, generics); let ty = builder.ty().path() .segment(item.ident).with_generics(impl_generics.clone()).build() @@ -89,6 +85,117 @@ fn serialize_item( ).unwrap()) } +// All the generics in the input, plus a bound `T: Serialize` for each field +// type that will be serialized by us. +fn build_impl_generics( + cx: &ExtCtxt, + builder: &aster::AstBuilder, + item: &Item, + generics: &ast::Generics, +) -> ast::Generics { + let serialize_path = builder.path() + .global() + .ids(&["serde", "ser", "Serialize"]) + .build(); + + builder.from_generics(generics.clone()) + .with_predicates( + all_variants(cx, item).iter() + .flat_map(|variant_data| all_struct_fields(variant_data)) + .filter(|field| serialized_by_us(field)) + .map(|field| &field.node.ty) + .map(|ty| strip_reference(ty)) + .map(|ty| builder.where_predicate() + // the type that is being bounded i.e. T + .bound().build(ty.clone()) + // the bound i.e. Serialize + .bound().trait_(serialize_path.clone()).build() + .build())) + .build() +} + +fn all_variants<'a>(cx: &ExtCtxt, item: &'a Item) -> Vec<&'a ast::VariantData> { + match item.node { + ast::ItemKind::Struct(ref variant_data, _) => { + vec![variant_data] + } + ast::ItemKind::Enum(ref enum_def, _) => { + enum_def.variants.iter() + .map(|variant| &variant.node.data) + .collect() + } + _ => { + cx.span_bug(item.span, + "expected Item to be Struct or Enum in #[derive(Serialize)]"); + } + } +} + +fn all_struct_fields(variant_data: &ast::VariantData) -> &[ast::StructField] { + match *variant_data { + ast::VariantData::Struct(ref fields, _) | + ast::VariantData::Tuple(ref fields, _) => { + fields + } + ast::VariantData::Unit(_) => { + &[] + } + } +} + +// Fields with a `skip_serializing` or `serialize_with` attribute are not +// serialized by us. All other fields will receive a `T: Serialize` bound where +// T is the type of the field. +fn serialized_by_us(field: &ast::StructField) -> bool { + for meta_items in field.node.attrs.iter().filter_map(attr::get_serde_meta_items) { + for meta_item in meta_items { + match meta_item.node { + ast::MetaItemKind::Word(ref name) if name == &"skip_serializing" => { + return false + } + ast::MetaItemKind::NameValue(ref name, _) if name == &"serialize_with" => { + return false + } + _ => {} + } + } + } + true +} + +// This is required to handle types that use both a reference and a value of +// the same type, as in: +// +// enum Test<'a, T> where T: 'a { +// Lifetime(&'a T), +// NoLifetime(T), +// } +// +// Preserving references, we would generate an impl like: +// +// impl<'a, T> Serialize for Test<'a, T> +// where &'a T: Serialize, +// T: Serialize { ... } +// +// And taking a reference to one of the elements would fail with: +// +// error: cannot infer an appropriate lifetime for pattern due +// to conflicting requirements [E0495] +// Test::NoLifetime(ref v) => { ... } +// ^~~~~ +// +// Instead, we strip references before adding `T: Serialize` bounds in order to +// generate: +// +// impl<'a, T> Serialize for Test<'a, T> +// where T: Serialize { ... } +fn strip_reference(ty: &P) -> &P { + match ty.node { + ast::TyKind::Rptr(_, ref mut_ty) => &mut_ty.ty, + _ => ty + } +} + fn serialize_body( cx: &ExtCtxt, builder: &aster::AstBuilder, diff --git a/serde_tests/tests/test_annotations.rs b/serde_tests/tests/test_annotations.rs index cdefac86..fcf5db85 100644 --- a/serde_tests/tests/test_annotations.rs +++ b/serde_tests/tests/test_annotations.rs @@ -10,23 +10,33 @@ use token::{ assert_de_tokens_error }; -trait Trait: Sized { +trait MyDefault: Sized { fn my_default() -> Self; +} +trait ShouldSkip: Sized { fn should_skip(&self) -> bool; +} +trait SerializeWith: Sized { fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> where S: Serializer; +} +trait DeserializeWith: Sized { fn deserialize_with(de: &mut D) -> Result where D: Deserializer; } -impl Trait for i32 { +impl MyDefault for i32 { fn my_default() -> Self { 123 } +} +impl ShouldSkip for i32 { fn should_skip(&self) -> bool { *self == 123 } +} +impl SerializeWith for i32 { fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> where S: Serializer { @@ -36,7 +46,9 @@ impl Trait for i32 { false.serialize(ser) } } +} +impl DeserializeWith for i32 { fn deserialize_with(de: &mut D) -> Result where D: Deserializer { @@ -49,11 +61,11 @@ impl Trait for i32 { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -struct DefaultStruct where C: Trait { +struct DefaultStruct where C: MyDefault { a1: A, #[serde(default)] a2: B, - #[serde(default="Trait::my_default")] + #[serde(default="MyDefault::my_default")] a3: C, } @@ -95,12 +107,12 @@ fn test_default_struct() { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -enum DefaultEnum where C: Trait { +enum DefaultEnum where C: MyDefault { Struct { a1: A, #[serde(default)] a2: B, - #[serde(default="Trait::my_default")] + #[serde(default="MyDefault::my_default")] a3: C, } } @@ -389,11 +401,11 @@ fn test_rename_enum() { } #[derive(Debug, PartialEq, Serialize)] -struct SkipSerializingStruct<'a, B, C> where C: Trait { +struct SkipSerializingStruct<'a, B, C> where C: ShouldSkip { a: &'a i8, #[serde(skip_serializing)] b: B, - #[serde(skip_serializing_if="Trait::should_skip")] + #[serde(skip_serializing_if="ShouldSkip::should_skip")] c: C, } @@ -440,12 +452,12 @@ fn test_skip_serializing_struct() { } #[derive(Debug, PartialEq, Serialize)] -enum SkipSerializingEnum<'a, B, C> where C: Trait { +enum SkipSerializingEnum<'a, B, C> where C: ShouldSkip { Struct { a: &'a i8, #[serde(skip_serializing)] _b: B, - #[serde(skip_serializing_if="Trait::should_skip")] + #[serde(skip_serializing_if="ShouldSkip::should_skip")] c: C, } } @@ -492,10 +504,62 @@ fn test_skip_serializing_enum() { ); } +#[derive(Debug, PartialEq)] +struct NotSerializeStruct(i8); + +#[derive(Debug, PartialEq)] +enum NotSerializeEnum { Trouble } + +impl SerializeWith for NotSerializeEnum { + fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> + where S: Serializer + { + "trouble".serialize(ser) + } +} + #[derive(Debug, PartialEq, Serialize)] -struct SerializeWithStruct<'a, B> where B: Trait { +struct ContainsNotSerialize<'a, B, C, D> where B: 'a, D: SerializeWith { + a: &'a Option, + #[serde(skip_serializing)] + b: &'a B, + #[serde(skip_serializing)] + c: Option, + #[serde(serialize_with="SerializeWith::serialize_with")] + d: D, +} + +#[test] +fn test_elt_not_serialize() { + let a = 1; + assert_ser_tokens( + &ContainsNotSerialize { + a: &Some(a), + b: &NotSerializeStruct(2), + c: Some(NotSerializeEnum::Trouble), + d: NotSerializeEnum::Trouble, + }, + &[ + Token::StructStart("ContainsNotSerialize", Some(2)), + + Token::StructSep, + Token::Str("a"), + Token::Option(true), + Token::I8(1), + + Token::StructSep, + Token::Str("d"), + Token::Str("trouble"), + + Token::StructEnd, + ] + ); +} + +#[derive(Debug, PartialEq, Serialize)] +struct SerializeWithStruct<'a, B> where B: SerializeWith { a: &'a i8, - #[serde(serialize_with="Trait::serialize_with")] + #[serde(serialize_with="SerializeWith::serialize_with")] b: B, } @@ -544,10 +608,10 @@ fn test_serialize_with_struct() { } #[derive(Debug, PartialEq, Serialize)] -enum SerializeWithEnum<'a, B> where B: Trait { +enum SerializeWithEnum<'a, B> where B: SerializeWith { Struct { a: &'a i8, - #[serde(serialize_with="Trait::serialize_with")] + #[serde(serialize_with="SerializeWith::serialize_with")] b: B, } } @@ -597,9 +661,9 @@ fn test_serialize_with_enum() { } #[derive(Debug, PartialEq, Deserialize)] -struct DeserializeWithStruct where B: Trait { +struct DeserializeWithStruct where B: DeserializeWith { a: i8, - #[serde(deserialize_with="Trait::deserialize_with")] + #[serde(deserialize_with="DeserializeWith::deserialize_with")] b: B, } @@ -647,10 +711,10 @@ fn test_deserialize_with_struct() { } #[derive(Debug, PartialEq, Deserialize)] -enum DeserializeWithEnum where B: Trait { +enum DeserializeWithEnum where B: DeserializeWith { Struct { a: i8, - #[serde(deserialize_with="Trait::deserialize_with")] + #[serde(deserialize_with="DeserializeWith::deserialize_with")] b: B, } }