From 9a3c1243f4acc14c83463bb5147c22a290456d7d Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sun, 19 Feb 2017 16:04:50 -0800 Subject: [PATCH] Deserialization of Haskell style enums --- serde/src/de/content.rs | 46 ++++++- serde/src/de/private.rs | 3 +- serde_derive/src/de.rs | 235 +++++++++++++++++++++++++++++++- test_suite/tests/test_macros.rs | 125 +++++++++++++++-- 4 files changed, 394 insertions(+), 15 deletions(-) diff --git a/serde/src/de/content.rs b/serde/src/de/content.rs index 1ae2e869..5ba2a169 100644 --- a/serde/src/de/content.rs +++ b/serde/src/de/content.rs @@ -21,7 +21,7 @@ use collections::{String, Vec}; use alloc::boxed::Box; use de::{self, Deserialize, DeserializeSeed, Deserializer, Visitor, SeqVisitor, MapVisitor, - EnumVisitor}; + EnumVisitor, Unexpected}; /// Used from generated code to buffer the contents of the Deserializer when /// deserializing untagged enums and internally tagged enums. @@ -493,6 +493,50 @@ impl Visitor for TaggedContentVisitor } } +/// Used by generated code to deserialize an adjacently tagged enum. +/// +/// Not public API. +pub enum TagOrContentField { + Tag, + Content, +} + +/// Not public API. +pub struct TagOrContentFieldVisitor { + pub tag: &'static str, + pub content: &'static str, +} + +impl DeserializeSeed for TagOrContentFieldVisitor { + type Value = TagOrContentField; + + fn deserialize(self, deserializer: D) -> Result + where D: Deserializer + { + deserializer.deserialize_str(self) + } +} + +impl Visitor for TagOrContentFieldVisitor { + type Value = TagOrContentField; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "{:?} or {:?}", self.tag, self.content) + } + + fn visit_str(self, field: &str) -> Result + where E: de::Error + { + if field == self.tag { + Ok(TagOrContentField::Tag) + } else if field == self.content { + Ok(TagOrContentField::Content) + } else { + Err(de::Error::invalid_value(Unexpected::Str(field), &self)) + } + } +} + /// Not public API pub struct ContentDeserializer { content: Content, diff --git a/serde/src/de/private.rs b/serde/src/de/private.rs index 8ef62231..092d66a6 100644 --- a/serde/src/de/private.rs +++ b/serde/src/de/private.rs @@ -4,7 +4,8 @@ use de::{Deserialize, Deserializer, Error, Visitor}; #[cfg(any(feature = "std", feature = "collections"))] pub use de::content::{Content, ContentRefDeserializer, ContentDeserializer, TaggedContentVisitor, - InternallyTaggedUnitVisitor, UntaggedUnitVisitor}; + TagOrContentField, TagOrContentFieldVisitor, InternallyTaggedUnitVisitor, + UntaggedUnitVisitor}; /// If the missing field is of type `Option` then treat is as `None`, /// otherwise it is an error. diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index b73d1669..b47f334b 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -185,7 +185,7 @@ fn deserialize_tuple(ident: &syn::Ident, __Visitor { marker: _serde::export::PhantomData::<#ident #ty_generics> } }; let dispatch = if let Some(deserializer) = deserializer { - quote!(_serde::Deserializer::deserialize(#deserializer, #visitor_expr)) + quote!(_serde::Deserializer::deserialize_tuple(#deserializer, #nfields, #visitor_expr)) } else if is_enum { quote!(_serde::de::VariantVisitor::visit_tuple(visitor, #nfields, #visitor_expr)) } else if nfields == 1 { @@ -442,7 +442,14 @@ fn deserialize_item_enum(ident: &syn::Ident, item_attrs, tag) } - attr::EnumTag::Adjacent { .. } => unimplemented!(), + attr::EnumTag::Adjacent { ref tag, ref content } => { + deserialize_adjacently_tagged_enum(ident, + generics, + variants, + item_attrs, + tag, + content) + } attr::EnumTag::None => { deserialize_untagged_enum(ident, generics, variants, item_attrs) } @@ -597,6 +604,230 @@ fn deserialize_internally_tagged_enum(ident: &syn::Ident, } } +fn deserialize_adjacently_tagged_enum(ident: &syn::Ident, + generics: &syn::Generics, + variants: &[Variant], + item_attrs: &attr::Item, + tag: &str, + content: &str) + -> Fragment { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let variant_names_idents: Vec<_> = variants.iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| (variant.attrs.name().deserialize_name(), field_i(i))) + .collect(); + + let variants_stmt = { + let variant_names = variant_names_idents.iter().map(|&(ref name, _)| name); + quote! { + const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ]; + } + }; + + let variant_visitor = Stmts(deserialize_field_visitor(variant_names_idents, item_attrs, true)); + + let ref variant_arms: Vec<_> = variants.iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| { + let variant_index = field_i(i); + + let block = Match(deserialize_untagged_variant( + ident, + generics, + variant, + item_attrs, + quote!(_deserializer), + )); + + quote! { + __Field::#variant_index => #block + } + }) + .collect(); + + let expecting = format!("adjacently tagged enum {}", ident); + let type_name = item_attrs.name().deserialize_name(); + + let tag_or_content = quote! { + _serde::de::private::TagOrContentFieldVisitor { + tag: #tag, + content: #content, + } + }; + + fn is_unit(variant: &Variant) -> bool { + match variant.style { + Style::Unit => true, + Style::Struct | Style::Tuple | Style::Newtype => false, + } + } + + let mut missing_content = quote! { + _serde::export::Err(<__V::Error as _serde::de::Error>::missing_field(#content)) + }; + if variants.iter().any(is_unit) { + let fallthrough = if variants.iter().all(is_unit) { + None + } else { + Some(quote! { + _ => #missing_content + }) + }; + let arms = variants.iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing() && is_unit(variant)) + .map(|(i, variant)| { + let variant_index = field_i(i); + let variant_ident = &variant.ident; + quote! { + __Field::#variant_index => _serde::export::Ok(#ident::#variant_ident), + } + }); + missing_content = quote! { + match __field { + #(#arms)* + #fallthrough + } + }; + } + + let visit_third_key = quote! { + // Visit the third key in the map, hopefully there isn't one. + match try!(_serde::de::MapVisitor::visit_key_seed(&mut visitor, #tag_or_content)) { + _serde::export::Some(_serde::de::private::TagOrContentField::Tag) => { + _serde::export::Err(<__V::Error as _serde::de::Error>::duplicate_field(#tag)) + } + _serde::export::Some(_serde::de::private::TagOrContentField::Content) => { + _serde::export::Err(<__V::Error as _serde::de::Error>::duplicate_field(#content)) + } + _serde::export::None => _serde::export::Ok(__ret), + } + }; + + quote_block! { + #variant_visitor + + #variants_stmt + + struct __Seed #impl_generics #where_clause { + field: __Field, + marker: _serde::export::PhantomData<#ident #ty_generics>, + } + + impl #impl_generics _serde::de::DeserializeSeed for __Seed #ty_generics #where_clause { + type Value = #ident #ty_generics; + + fn deserialize<__D>(self, _deserializer: __D) -> _serde::export::Result + where __D: _serde::Deserializer + { + match self.field { + #(#variant_arms)* + } + } + } + + struct __Visitor #impl_generics #where_clause { + marker: _serde::export::PhantomData<#ident #ty_generics>, + } + + impl #impl_generics _serde::de::Visitor for __Visitor #ty_generics #where_clause { + type Value = #ident #ty_generics; + + fn expecting(&self, formatter: &mut _serde::export::fmt::Formatter) -> _serde::export::fmt::Result { + _serde::export::fmt::Formatter::write_str(formatter, #expecting) + } + + fn visit_map<__V>(self, mut visitor: __V) -> _serde::export::Result + where __V: _serde::de::MapVisitor + { + // Visit the first key. + match try!(_serde::de::MapVisitor::visit_key_seed(&mut visitor, #tag_or_content)) { + // First key is the tag. + _serde::export::Some(_serde::de::private::TagOrContentField::Tag) => { + // Parse the tag. + let __field = try!(_serde::de::MapVisitor::visit_value(&mut visitor)); + // Visit the second key. + match try!(_serde::de::MapVisitor::visit_key_seed(&mut visitor, #tag_or_content)) { + // Second key is a duplicate of the tag. + _serde::export::Some(_serde::de::private::TagOrContentField::Tag) => { + _serde::export::Err(<__V::Error as _serde::de::Error>::duplicate_field(#tag)) + } + // Second key is the content. + _serde::export::Some(_serde::de::private::TagOrContentField::Content) => { + let __ret = try!(_serde::de::MapVisitor::visit_value_seed(&mut visitor, __Seed { field: __field, marker: _serde::export::PhantomData })); + // Visit the third key, hopefully there isn't one. + #visit_third_key + } + // There is no second key; might be okay if the we have a unit variant. + _serde::export::None => #missing_content + } + } + // First key is the content. + _serde::export::Some(_serde::de::private::TagOrContentField::Content) => { + // Buffer up the content. + let __content = try!(_serde::de::MapVisitor::visit_value::<_serde::de::private::Content>(&mut visitor)); + // Visit the second key. + match try!(_serde::de::MapVisitor::visit_key_seed(&mut visitor, #tag_or_content)) { + // Second key is the tag. + _serde::export::Some(_serde::de::private::TagOrContentField::Tag) => { + let _deserializer = _serde::de::private::ContentDeserializer::<__V::Error>::new(__content); + // Parse the tag. + let __ret = try!(match try!(_serde::de::MapVisitor::visit_value(&mut visitor)) { + // Deserialize the buffered content now that we know the variant. + #(#variant_arms)* + }); + // Visit the third key, hopefully there isn't one. + #visit_third_key + } + // Second key is a duplicate of the content. + _serde::export::Some(_serde::de::private::TagOrContentField::Content) => { + _serde::export::Err(<__V::Error as _serde::de::Error>::duplicate_field(#content)) + } + // There is no second key. + _serde::export::None => { + _serde::export::Err(<__V::Error as _serde::de::Error>::missing_field(#tag)) + } + } + } + // There is no first key. + _serde::export::None => { + _serde::export::Err(<__V::Error as _serde::de::Error>::missing_field(#tag)) + } + } + } + + fn visit_seq<__V>(self, mut visitor: __V) -> _serde::export::Result + where __V: _serde::de::SeqVisitor + { + // Visit the first element - the tag. + match try!(_serde::de::SeqVisitor::visit(&mut visitor)) { + _serde::export::Some(__field) => { + // Visit the second element - the content. + match try!(_serde::de::SeqVisitor::visit_seed(&mut visitor, __Seed { field: __field, marker: _serde::export::PhantomData })) { + _serde::export::Some(__ret) => _serde::export::Ok(__ret), + // There is no second element. + _serde::export::None => { + _serde::export::Err(_serde::de::Error::invalid_length(1, &self)) + } + } + } + // There is no first element. + _serde::export::None => { + _serde::export::Err(_serde::de::Error::invalid_length(0, &self)) + } + } + } + } + + const FIELDS: &'static [&'static str] = &[#tag, #content]; + _serde::Deserializer::deserialize_struct(deserializer, #type_name, FIELDS, + __Visitor { marker: _serde::export::PhantomData::<#ident #ty_generics> }) + } +} + fn deserialize_untagged_enum(ident: &syn::Ident, generics: &syn::Generics, variants: &[Variant], diff --git a/test_suite/tests/test_macros.rs b/test_suite/tests/test_macros.rs index 4b84daeb..f4ef675d 100644 --- a/test_suite/tests/test_macros.rs +++ b/test_suite/tests/test_macros.rs @@ -884,17 +884,18 @@ fn test_internally_tagged_enum() { #[test] fn test_adjacently_tagged_enum() { - #[derive(Debug, PartialEq, Serialize)] + #[derive(Debug, PartialEq, Serialize, Deserialize)] #[serde(tag = "t", content = "c")] - enum AdjacentlyTagged { + enum AdjacentlyTagged { Unit, - Newtype(u8), + Newtype(T), Tuple(u8, u8), Struct { f: u8 }, } - assert_ser_tokens( - &AdjacentlyTagged::Unit, + // unit with no content + assert_tokens( + &AdjacentlyTagged::Unit::, &[ Token::StructStart("AdjacentlyTagged", 1), @@ -906,8 +907,45 @@ fn test_adjacently_tagged_enum() { ] ); - assert_ser_tokens( - &AdjacentlyTagged::Newtype(1), + // unit with tag first + assert_de_tokens( + &AdjacentlyTagged::Unit::, + &[ + Token::StructStart("AdjacentlyTagged", 1), + + Token::StructSep, + Token::Str("t"), + Token::Str("Unit"), + + Token::StructSep, + Token::Str("c"), + Token::Unit, + + Token::StructEnd, + ] + ); + + // unit with content first + assert_de_tokens( + &AdjacentlyTagged::Unit::, + &[ + Token::StructStart("AdjacentlyTagged", 1), + + Token::StructSep, + Token::Str("c"), + Token::Unit, + + Token::StructSep, + Token::Str("t"), + Token::Str("Unit"), + + Token::StructEnd, + ] + ); + + // newtype with tag first + assert_tokens( + &AdjacentlyTagged::Newtype::(1), &[ Token::StructStart("AdjacentlyTagged", 2), @@ -923,8 +961,27 @@ fn test_adjacently_tagged_enum() { ] ); - assert_ser_tokens( - &AdjacentlyTagged::Tuple(1, 1), + // newtype with content first + assert_de_tokens( + &AdjacentlyTagged::Newtype::(1), + &[ + Token::StructStart("AdjacentlyTagged", 2), + + Token::StructSep, + Token::Str("c"), + Token::U8(1), + + Token::StructSep, + Token::Str("t"), + Token::Str("Newtype"), + + Token::StructEnd, + ] + ); + + // tuple with tag first + assert_tokens( + &AdjacentlyTagged::Tuple::(1, 1), &[ Token::StructStart("AdjacentlyTagged", 2), @@ -945,8 +1002,32 @@ fn test_adjacently_tagged_enum() { ] ); - assert_ser_tokens( - &AdjacentlyTagged::Struct { f: 1 }, + // tuple with content first + assert_de_tokens( + &AdjacentlyTagged::Tuple::(1, 1), + &[ + Token::StructStart("AdjacentlyTagged", 2), + + Token::StructSep, + Token::Str("c"), + Token::TupleStart(2), + Token::TupleSep, + Token::U8(1), + Token::TupleSep, + Token::U8(1), + Token::TupleEnd, + + Token::StructSep, + Token::Str("t"), + Token::Str("Tuple"), + + Token::StructEnd, + ] + ); + + // struct with tag first + assert_tokens( + &AdjacentlyTagged::Struct:: { f: 1 }, &[ Token::StructStart("AdjacentlyTagged", 2), @@ -965,4 +1046,26 @@ fn test_adjacently_tagged_enum() { Token::StructEnd, ] ); + + // struct with content first + assert_de_tokens( + &AdjacentlyTagged::Struct:: { f: 1 }, + &[ + Token::StructStart("AdjacentlyTagged", 2), + + Token::StructSep, + Token::Str("c"), + Token::StructStart("Struct", 1), + Token::StructSep, + Token::Str("f"), + Token::U8(1), + Token::StructEnd, + + Token::StructSep, + Token::Str("t"), + Token::Str("Struct"), + + Token::StructEnd, + ] + ); }