From b98a9a8f9b7b64b70935d0980f6f55a8767c1e06 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sun, 17 Sep 2017 13:45:07 -0700 Subject: [PATCH] Support deserializing internally tagged enum from seq During serialization, internally tagged enums invoke the Serializer's serialize_struct. In JSON this turns into a map which uses visit_map when deserialized. But some formats employ visit_seq when deserializing a struct. One example is rmp-serde. Such formats were previously unable to deserialize an internally tagged enum. This change fixes it by adding visit_seq for internally tagged enums. --- serde/src/private/de.rs | 42 +++++++++++---- serde_derive/src/de.rs | 56 ++++++++++++++------ test_suite/tests/test_macros.rs | 92 +++++++++++++++++++++------------ 3 files changed, 130 insertions(+), 60 deletions(-) diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index a79b5a1f..3d8a2ec3 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -834,26 +834,43 @@ mod content { type Value = TaggedContent<'de, T>; fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.write_str("any value") + fmt.write_str("internally tagged enum") } - fn visit_map(self, mut visitor: V) -> Result + fn visit_seq(self, mut seq: S) -> Result where - V: MapAccess<'de>, + S: SeqAccess<'de>, + { + let tag = match try!(seq.next_element()) { + Some(tag) => tag, + None => { + return Err(de::Error::missing_field(self.tag_name)); + } + }; + let rest = de::value::SeqAccessDeserializer::new(seq); + Ok(TaggedContent { + tag: tag, + content: try!(Content::deserialize(rest)), + }) + } + + fn visit_map(self, mut map: M) -> Result + where + M: MapAccess<'de>, { let mut tag = None; - let mut vec = Vec::with_capacity(size_hint::cautious(visitor.size_hint())); + let mut vec = Vec::with_capacity(size_hint::cautious(map.size_hint())); while let Some(k) = - try!(visitor.next_key_seed(TagOrContentVisitor::new(self.tag_name))) { + try!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) { match k { TagOrContent::Tag => { if tag.is_some() { return Err(de::Error::duplicate_field(self.tag_name)); } - tag = Some(try!(visitor.next_value())); + tag = Some(try!(map.next_value())); } TagOrContent::Content(k) => { - let v = try!(visitor.next_value()); + let v = try!(map.next_value()); vec.push((k, v)); } } @@ -1802,9 +1819,16 @@ mod content { write!(formatter, "unit variant {}::{}", self.type_name, self.variant_name) } - fn visit_map(self, _: V) -> Result<(), V::Error> + fn visit_seq(self, _: S) -> Result<(), S::Error> where - V: MapAccess<'de>, + S: SeqAccess<'de>, + { + Ok(()) + } + + fn visit_map(self, _: M) -> Result<(), M::Error> + where + M: MapAccess<'de>, { Ok(()) } diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index fe064dd5..44ed7fa0 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -222,7 +222,7 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { if fields.iter().any(|field| field.ident.is_none()) { panic!("struct has unnamed fields"); } - deserialize_struct(None, params, fields, &cont.attrs, None) + deserialize_struct(None, params, fields, &cont.attrs, None, Untagged::No) } Body::Struct(Style::Tuple, ref fields) | Body::Struct(Style::Newtype, ref fields) => { @@ -488,15 +488,20 @@ fn deserialize_newtype_struct(type_path: &Tokens, params: &Parameters, field: &F } } +enum Untagged { + Yes, + No, +} + fn deserialize_struct( variant_ident: Option<&syn::Ident>, params: &Parameters, fields: &[Field], cattrs: &attr::Container, deserializer: Option, + untagged: Untagged, ) -> Fragment { let is_enum = variant_ident.is_some(); - let is_untagged = deserializer.is_some(); let this = ¶ms.this; let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = split_with_de_lifetime(params,); @@ -559,18 +564,19 @@ fn deserialize_struct( quote!(mut __seq) }; - let visit_seq = if is_untagged { - // untagged struct variants do not get a visit_seq method - None - } else { - Some(quote! { - #[inline] - fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::export::Result - where __A: _serde::de::SeqAccess<#delife> - { - #visit_seq - } - }) + // untagged struct variants do not get a visit_seq method + let visit_seq = match untagged { + Untagged::Yes => None, + Untagged::No => { + Some(quote! { + #[inline] + fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::export::Result + where __A: _serde::de::SeqAccess<#delife> + { + #visit_seq + } + }) + } }; quote_block! { @@ -1148,7 +1154,7 @@ fn deserialize_externally_tagged_variant( deserialize_tuple(Some(variant_ident), params, &variant.fields, cattrs, None) } Style::Struct => { - deserialize_struct(Some(variant_ident), params, &variant.fields, cattrs, None) + deserialize_struct(Some(variant_ident), params, &variant.fields, cattrs, None, Untagged::No) } } } @@ -1175,8 +1181,23 @@ fn deserialize_internally_tagged_variant( _serde::export::Ok(#this::#variant_ident) } } - Style::Newtype | Style::Struct => { - deserialize_untagged_variant(params, variant, cattrs, deserializer) + Style::Newtype => { + deserialize_untagged_newtype_variant( + variant_ident, + params, + &variant.fields[0], + deserializer, + ) + } + Style::Struct => { + deserialize_struct( + Some(variant_ident), + params, + &variant.fields, + cattrs, + Some(deserializer), + Untagged::No, + ) } Style::Tuple => unreachable!("checked in serde_derive_internals"), } @@ -1238,6 +1259,7 @@ fn deserialize_untagged_variant( &variant.fields, cattrs, Some(deserializer), + Untagged::Yes, ) } } diff --git a/test_suite/tests/test_macros.rs b/test_suite/tests/test_macros.rs index 66e72f23..0ea2578a 100644 --- a/test_suite/tests/test_macros.rs +++ b/test_suite/tests/test_macros.rs @@ -591,11 +591,10 @@ fn test_internally_tagged_enum() { #[serde(tag = "type")] enum InternallyTagged { A { a: u8 }, - B { b: u8 }, - C, - D(BTreeMap), - E(Newtype), - F(Struct), + B, + C(BTreeMap), + D(Newtype), + E(Struct), } assert_tokens( @@ -613,35 +612,62 @@ fn test_internally_tagged_enum() { ], ); - assert_tokens( - &InternallyTagged::B { b: 2 }, + assert_de_tokens( + &InternallyTagged::A { a: 1 }, &[ - Token::Struct { name: "InternallyTagged", len: 2 }, - - Token::Str("type"), - Token::Str("B"), - - Token::Str("b"), - Token::U8(2), - - Token::StructEnd, + Token::Seq { len: Some(2) }, + Token::Str("A"), + Token::U8(1), + Token::SeqEnd, ], ); assert_tokens( - &InternallyTagged::C, + &InternallyTagged::B, &[ Token::Struct { name: "InternallyTagged", len: 1 }, Token::Str("type"), - Token::Str("C"), + Token::Str("B"), Token::StructEnd, ], ); + assert_de_tokens( + &InternallyTagged::B, + &[ + Token::Seq { len: Some(1) }, + Token::Str("B"), + Token::SeqEnd, + ], + ); + assert_tokens( - &InternallyTagged::D(BTreeMap::new()), + &InternallyTagged::C(BTreeMap::new()), + &[ + Token::Map { len: Some(1) }, + + Token::Str("type"), + Token::Str("C"), + + Token::MapEnd, + ], + ); + + assert_de_tokens_error::( + &[ + Token::Seq { len: Some(2) }, + Token::Str("C"), + Token::Map { len: Some(0) }, + Token::MapEnd, + Token::SeqEnd, + ], + "invalid type: sequence, expected a map", + ); + + assert_tokens( + &InternallyTagged::D(Newtype(BTreeMap::new())), &[ Token::Map { len: Some(1) }, @@ -653,24 +679,12 @@ fn test_internally_tagged_enum() { ); assert_tokens( - &InternallyTagged::E(Newtype(BTreeMap::new())), - &[ - Token::Map { len: Some(1) }, - - Token::Str("type"), - Token::Str("E"), - - Token::MapEnd, - ], - ); - - assert_tokens( - &InternallyTagged::F(Struct { f: 6 }), + &InternallyTagged::E(Struct { f: 6 }), &[ Token::Struct { name: "Struct", len: 2 }, Token::Str("type"), - Token::Str("F"), + Token::Str("E"), Token::Str("f"), Token::U8(6), @@ -679,6 +693,16 @@ fn test_internally_tagged_enum() { ], ); + assert_de_tokens( + &InternallyTagged::E(Struct { f: 6 }), + &[ + Token::Seq { len: Some(2) }, + Token::Str("E"), + Token::U8(6), + Token::SeqEnd, + ], + ); + assert_de_tokens_error::( &[Token::Map { len: Some(0) }, Token::MapEnd], "missing field `type`", @@ -693,7 +717,7 @@ fn test_internally_tagged_enum() { Token::MapEnd, ], - "unknown variant `Z`, expected one of `A`, `B`, `C`, `D`, `E`, `F`", + "unknown variant `Z`, expected one of `A`, `B`, `C`, `D`, `E`", ); }