diff --git a/serde/src/de/content.rs b/serde/src/de/content.rs new file mode 100644 index 00000000..fa8875c8 --- /dev/null +++ b/serde/src/de/content.rs @@ -0,0 +1,743 @@ +// This module is doc(hidden) and nothing here should be used outside of +// generated code. +// +// We will iterate on the implementation for a few releases and only have to +// worry about backward compatibility for the `untagged` and `tag` attributes +// rather than for this entire mechanism. +// +// This issue is tracking making some of this stuff public: +// https://github.com/serde-rs/serde/issues/741 + +#![doc(hidden)] + +use core::fmt; +use core::marker::PhantomData; + +#[cfg(all(not(feature = "std"), feature = "collections"))] +use collections::{String, Vec}; + +#[cfg(all(feature = "alloc", not(feature = "std")))] +use alloc::boxed::Box; + +use de::{ + self, + Deserialize, + DeserializeSeed, + Deserializer, + Visitor, + SeqVisitor, + MapVisitor, + EnumVisitor, +}; + +/// Used from generated code to buffer the contents of the Deserializer when +/// deserializing untagged enums and internally tagged enums. +/// +/// Not public API. Use serde-value instead. +#[derive(Debug)] +pub enum Content { + // Don't mind the PhantomData, just need to use E somewhere. + Bool(bool, PhantomData), + + U8(u8), + U16(u16), + U32(u32), + U64(u64), + + I8(i8), + I16(i16), + I32(i32), + I64(i64), + + F32(f32), + F64(f64), + + Char(char), + String(String), + Bytes(Vec), + + None, + Some(Box>), + + Unit, + Newtype(Box>), + Seq(Vec>), + Map(Vec<(Content, Content)>), +} + +impl Deserialize for Content { + fn deserialize(deserializer: D) -> Result { + // Untagged and internally tagged enums are only supported in + // self-describing formats. + deserializer.deserialize(ContentVisitor::new()) + } +} + +struct ContentVisitor { + err: PhantomData, +} + +impl ContentVisitor { + fn new() -> Self { + ContentVisitor { + err: PhantomData, + } + } +} + +impl Visitor for ContentVisitor { + type Value = Content; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("any value") + } + + fn visit_bool(self, value: bool) -> Result + where F: de::Error + { + Ok(Content::Bool(value, PhantomData)) + } + + fn visit_i8(self, value: i8) -> Result + where F: de::Error + { + Ok(Content::I8(value)) + } + + fn visit_i16(self, value: i16) -> Result + where F: de::Error + { + Ok(Content::I16(value)) + } + + fn visit_i32(self, value: i32) -> Result + where F: de::Error + { + Ok(Content::I32(value)) + } + + fn visit_i64(self, value: i64) -> Result + where F: de::Error + { + Ok(Content::I64(value)) + } + + fn visit_u8(self, value: u8) -> Result + where F: de::Error + { + Ok(Content::U8(value)) + } + + fn visit_u16(self, value: u16) -> Result + where F: de::Error + { + Ok(Content::U16(value)) + } + + fn visit_u32(self, value: u32) -> Result + where F: de::Error + { + Ok(Content::U32(value)) + } + + fn visit_u64(self, value: u64) -> Result + where F: de::Error + { + Ok(Content::U64(value)) + } + + fn visit_f32(self, value: f32) -> Result + where F: de::Error + { + Ok(Content::F32(value)) + } + + fn visit_f64(self, value: f64) -> Result + where F: de::Error + { + Ok(Content::F64(value)) + } + + fn visit_char(self, value: char) -> Result + where F: de::Error + { + Ok(Content::Char(value)) + } + + fn visit_str(self, value: &str) -> Result + where F: de::Error + { + Ok(Content::String(value.into())) + } + + fn visit_string(self, value: String) -> Result + where F: de::Error + { + Ok(Content::String(value)) + } + + fn visit_bytes(self, value: &[u8]) -> Result + where F: de::Error + { + Ok(Content::Bytes(value.into())) + } + + fn visit_byte_buf(self, value: Vec) -> Result + where F: de::Error + { + Ok(Content::Bytes(value)) + } + + fn visit_unit(self) -> Result + where F: de::Error + { + Ok(Content::Unit) + } + + fn visit_none(self) -> Result + where F: de::Error + { + Ok(Content::None) + } + + fn visit_some(self, deserializer: D) -> Result + where D: Deserializer + { + Deserialize::deserialize(deserializer).map(|v| Content::Some(Box::new(v))) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where D: Deserializer + { + Deserialize::deserialize(deserializer).map(|v| Content::Newtype(Box::new(v))) + } + + fn visit_seq(self, mut visitor: V) -> Result + where V: SeqVisitor + { + let mut vec = Vec::with_capacity(visitor.size_hint().0); + while let Some(e) = try!(visitor.visit()) { + vec.push(e); + } + Ok(Content::Seq(vec)) + } + + fn visit_map(self, mut visitor: V) -> Result + where V: MapVisitor + { + let mut vec = Vec::with_capacity(visitor.size_hint().0); + while let Some(kv) = try!(visitor.visit()) { + vec.push(kv); + } + Ok(Content::Map(vec)) + } + + fn visit_enum(self, _visitor: V) -> Result + where V: EnumVisitor + { + Err(de::Error::custom("untagged and internally tagged enums do not support enum input")) + } +} + +/// This is the type of the map keys in an internally tagged enum. +/// +/// Not public API. +pub enum TagOrContent { + Tag, + Content(Content), +} + +struct TagOrContentVisitor { + name: &'static str, + err: PhantomData, +} + +impl TagOrContentVisitor { + fn new(name: &'static str) -> Self { + TagOrContentVisitor { + name: name, + err: PhantomData, + } + } +} + +impl DeserializeSeed for TagOrContentVisitor { + type Value = TagOrContent; + + fn deserialize(self, deserializer: D) -> Result + where D: Deserializer + { + // Internally tagged enums are only supported in self-describing + // formats. + deserializer.deserialize(self) + } +} + +impl Visitor for TagOrContentVisitor { + type Value = TagOrContent; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "a type tag `{}` or any other value", self.name) + } + + fn visit_bool(self, value: bool) -> Result + where F: de::Error + { + ContentVisitor::new().visit_bool(value).map(TagOrContent::Content) + } + + fn visit_i8(self, value: i8) -> Result + where F: de::Error + { + ContentVisitor::new().visit_i8(value).map(TagOrContent::Content) + } + + fn visit_i16(self, value: i16) -> Result + where F: de::Error + { + ContentVisitor::new().visit_i16(value).map(TagOrContent::Content) + } + + fn visit_i32(self, value: i32) -> Result + where F: de::Error + { + ContentVisitor::new().visit_i32(value).map(TagOrContent::Content) + } + + fn visit_i64(self, value: i64) -> Result + where F: de::Error + { + ContentVisitor::new().visit_i64(value).map(TagOrContent::Content) + } + + fn visit_u8(self, value: u8) -> Result + where F: de::Error + { + ContentVisitor::new().visit_u8(value).map(TagOrContent::Content) + } + + fn visit_u16(self, value: u16) -> Result + where F: de::Error + { + ContentVisitor::new().visit_u16(value).map(TagOrContent::Content) + } + + fn visit_u32(self, value: u32) -> Result + where F: de::Error + { + ContentVisitor::new().visit_u32(value).map(TagOrContent::Content) + } + + fn visit_u64(self, value: u64) -> Result + where F: de::Error + { + ContentVisitor::new().visit_u64(value).map(TagOrContent::Content) + } + + fn visit_f32(self, value: f32) -> Result + where F: de::Error + { + ContentVisitor::new().visit_f32(value).map(TagOrContent::Content) + } + + fn visit_f64(self, value: f64) -> Result + where F: de::Error + { + ContentVisitor::new().visit_f64(value).map(TagOrContent::Content) + } + + fn visit_char(self, value: char) -> Result + where F: de::Error + { + ContentVisitor::new().visit_char(value).map(TagOrContent::Content) + } + + fn visit_str(self, value: &str) -> Result + where F: de::Error + { + if value == self.name { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new().visit_str(value).map(TagOrContent::Content) + } + } + + fn visit_string(self, value: String) -> Result + where F: de::Error + { + if value == self.name { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new().visit_string(value).map(TagOrContent::Content) + } + } + + fn visit_bytes(self, value: &[u8]) -> Result + where F: de::Error + { + if value == self.name.as_bytes() { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new().visit_bytes(value).map(TagOrContent::Content) + } + } + + fn visit_byte_buf(self, value: Vec) -> Result + where F: de::Error + { + if value == self.name.as_bytes() { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new().visit_byte_buf(value).map(TagOrContent::Content) + } + } + + fn visit_unit(self) -> Result + where F: de::Error + { + ContentVisitor::new().visit_unit().map(TagOrContent::Content) + } + + fn visit_none(self) -> Result + where F: de::Error + { + ContentVisitor::new().visit_none().map(TagOrContent::Content) + } + + fn visit_some(self, deserializer: D) -> Result + where D: Deserializer + { + ContentVisitor::new().visit_some(deserializer).map(TagOrContent::Content) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where D: Deserializer + { + ContentVisitor::new().visit_newtype_struct(deserializer).map(TagOrContent::Content) + } + + fn visit_seq(self, visitor: V) -> Result + where V: SeqVisitor + { + ContentVisitor::new().visit_seq(visitor).map(TagOrContent::Content) + } + + fn visit_map(self, visitor: V) -> Result + where V: MapVisitor + { + ContentVisitor::new().visit_map(visitor).map(TagOrContent::Content) + } + + fn visit_enum(self, visitor: V) -> Result + where V: EnumVisitor + { + ContentVisitor::new().visit_enum(visitor).map(TagOrContent::Content) + } +} + +/// Used by generated code to deserialize an internally tagged enum. +/// +/// Not public API. +pub struct TaggedContent { + pub tag: T, + pub content: Content, +} + +/// Not public API. +pub struct TaggedContentVisitor { + tag_name: &'static str, + tag: PhantomData, + err: PhantomData, +} + +impl TaggedContentVisitor { + /// Visitor for the content of an internally tagged enum with the given tag + /// name. + pub fn new(name: &'static str) -> Self { + TaggedContentVisitor { + tag_name: name, + tag: PhantomData, + err: PhantomData, + } + } +} + +impl DeserializeSeed for TaggedContentVisitor + where T: Deserialize +{ + type Value = TaggedContent; + + fn deserialize(self, deserializer: D) -> Result + where D: Deserializer + { + // Internally tagged enums are only supported in self-describing + // formats. + deserializer.deserialize(self) + } +} + +impl Visitor for TaggedContentVisitor + where T: Deserialize +{ + type Value = TaggedContent; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("any value") + } + + fn visit_map(self, mut visitor: V) -> Result + where V: MapVisitor + { + let mut tag = None; + let mut vec = Vec::with_capacity(visitor.size_hint().0); + while let Some(k) = try!(visitor.visit_key_seed(TagOrContentVisitor::new(self.tag_name))) { + match k { + TagOrContent::Tag => { + if tag.is_some() { + return Err(de::Error::duplicate_field(self.tag_name)); + } + tag = Some(try!(visitor.visit_value())); + } + TagOrContent::Content(k) => { + let v = try!(visitor.visit_value()); + vec.push((k, v)); + } + } + } + match tag { + None => { + Err(de::Error::missing_field(self.tag_name)) + } + Some(tag) => { + Ok(TaggedContent { + tag: tag, + content: Content::Map(vec), + }) + } + } + } +} + +/// Used when deserializing an internally tagged enum because the content will +/// be used exactly once. +impl Deserializer for Content + where E: de::Error +{ + type Error = E; + + fn deserialize(self, visitor: V) -> Result + where V: Visitor + { + match self { + Content::Bool(v, _) => visitor.visit_bool(v), + Content::U8(v) => visitor.visit_u8(v), + Content::U16(v) => visitor.visit_u16(v), + Content::U32(v) => visitor.visit_u32(v), + Content::U64(v) => visitor.visit_u64(v), + Content::I8(v) => visitor.visit_i8(v), + Content::I16(v) => visitor.visit_i16(v), + Content::I32(v) => visitor.visit_i32(v), + Content::I64(v) => visitor.visit_i64(v), + Content::F32(v) => visitor.visit_f32(v), + Content::F64(v) => visitor.visit_f64(v), + Content::Char(v) => visitor.visit_char(v), + Content::String(v) => visitor.visit_string(v), + Content::Unit => visitor.visit_unit(), + Content::None => visitor.visit_none(), + Content::Some(v) => visitor.visit_some(*v), + Content::Newtype(v) => visitor.visit_newtype_struct(*v), + Content::Seq(v) => { + let seq = v.into_iter(); + let mut seq_visitor = de::value::SeqDeserializer::new(seq); + let value = try!(visitor.visit_seq(&mut seq_visitor)); + try!(seq_visitor.end()); + Ok(value) + }, + Content::Map(v) => { + let map = v.into_iter(); + let mut map_visitor = de::value::MapDeserializer::new(map); + let value = try!(visitor.visit_map(&mut map_visitor)); + try!(map_visitor.end()); + Ok(value) + }, + Content::Bytes(v) => visitor.visit_byte_buf(v), + } + } + + fn deserialize_option(self, visitor: V) -> Result + where V: Visitor + { + match self { + Content::None => visitor.visit_none(), + Content::Some(v) => visitor.visit_some(*v), + Content::Unit => visitor.visit_unit(), + _ => visitor.visit_some(self) + } + } + + fn deserialize_newtype_struct(self, _name: &str, visitor: V) -> Result + where V: Visitor + { + visitor.visit_newtype_struct(self) + } + + forward_to_deserialize! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq + seq_fixed_size bytes byte_buf map unit_struct tuple_struct struct + struct_field tuple enum ignored_any + } +} + +impl de::value::ValueDeserializer for Content + where E: de::Error +{ + type Deserializer = Self; + + fn into_deserializer(self) -> Self { + self + } +} + +/// Used when deserializing an untagged enum because the content may need to be +/// used more than once. +impl<'a, E> Deserializer for &'a Content + where E: de::Error +{ + type Error = E; + + fn deserialize(self, visitor: V) -> Result + where V: Visitor + { + match *self { + Content::Bool(v, _) => visitor.visit_bool(v), + Content::U8(v) => visitor.visit_u8(v), + Content::U16(v) => visitor.visit_u16(v), + Content::U32(v) => visitor.visit_u32(v), + Content::U64(v) => visitor.visit_u64(v), + Content::I8(v) => visitor.visit_i8(v), + Content::I16(v) => visitor.visit_i16(v), + Content::I32(v) => visitor.visit_i32(v), + Content::I64(v) => visitor.visit_i64(v), + Content::F32(v) => visitor.visit_f32(v), + Content::F64(v) => visitor.visit_f64(v), + Content::Char(v) => visitor.visit_char(v), + Content::String(ref v) => visitor.visit_str(v), + Content::Unit => visitor.visit_unit(), + Content::None => visitor.visit_none(), + Content::Some(ref v) => visitor.visit_some(&**v), + Content::Newtype(ref v) => visitor.visit_newtype_struct(&**v), + Content::Seq(ref v) => { + let seq = v.into_iter(); + let mut seq_visitor = de::value::SeqDeserializer::new(seq); + let value = try!(visitor.visit_seq(&mut seq_visitor)); + try!(seq_visitor.end()); + Ok(value) + }, + Content::Map(ref v) => { + let map = v.into_iter().map(|&(ref k, ref v)| (k, v)); + let mut map_visitor = de::value::MapDeserializer::new(map); + let value = try!(visitor.visit_map(&mut map_visitor)); + try!(map_visitor.end()); + Ok(value) + }, + Content::Bytes(ref v) => visitor.visit_bytes(v), + } + } + + fn deserialize_option(self, visitor: V) -> Result + where V: Visitor + { + match *self { + Content::None => visitor.visit_none(), + Content::Some(ref v) => visitor.visit_some(&**v), + Content::Unit => visitor.visit_unit(), + _ => visitor.visit_some(self) + } + } + + fn deserialize_newtype_struct(self, _name: &str, visitor: V) -> Result + where V: Visitor + { + visitor.visit_newtype_struct(self) + } + + forward_to_deserialize! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq + seq_fixed_size bytes byte_buf map unit_struct tuple_struct struct + struct_field tuple enum ignored_any + } +} + +impl<'a, E> de::value::ValueDeserializer for &'a Content + where E: de::Error +{ + type Deserializer = Self; + + fn into_deserializer(self) -> Self { + self + } +} + +/// Visitor for deserializing an internally tagged unit variant. +/// +/// Not public API. +pub struct InternallyTaggedUnitVisitor<'a> { + type_name: &'a str, + variant_name: &'a str, +} + +impl<'a> InternallyTaggedUnitVisitor<'a> { + /// Not public API. + pub fn new(type_name: &'a str, variant_name: &'a str) -> Self { + InternallyTaggedUnitVisitor { + type_name: type_name, + variant_name: variant_name, + } + } +} + +impl<'a> Visitor for InternallyTaggedUnitVisitor<'a> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "unit variant {}::{}", self.type_name, self.variant_name) + } + + fn visit_map(self, _: V) -> Result<(), V::Error> + where V: MapVisitor + { + Ok(()) + } +} + +/// Visitor for deserializing an untagged unit variant. +/// +/// Not public API. +pub struct UntaggedUnitVisitor<'a> { + type_name: &'a str, + variant_name: &'a str, +} + +impl<'a> UntaggedUnitVisitor<'a> { + /// Not public API. + pub fn new(type_name: &'a str, variant_name: &'a str) -> Self { + UntaggedUnitVisitor { + type_name: type_name, + variant_name: variant_name, + } + } +} + +impl<'a> Visitor for UntaggedUnitVisitor<'a> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "unit variant {}::{}", self.type_name, self.variant_name) + } + + fn visit_unit(self) -> Result<(), E> + where E: de::Error + { + Ok(()) + } +} diff --git a/serde/src/de/mod.rs b/serde/src/de/mod.rs index ce86ce1e..8471c995 100644 --- a/serde/src/de/mod.rs +++ b/serde/src/de/mod.rs @@ -115,6 +115,8 @@ mod from_primitive; // Helpers used by generated code. Not public API. #[doc(hidden)] pub mod private; +#[cfg(any(feature = "std", feature = "collections"))] +mod content; /////////////////////////////////////////////////////////////////////////////// diff --git a/serde/src/de/private.rs b/serde/src/de/private.rs index 1ff206cd..750eecdc 100644 --- a/serde/src/de/private.rs +++ b/serde/src/de/private.rs @@ -2,6 +2,14 @@ use core::marker::PhantomData; use de::{Deserialize, Deserializer, Error, Visitor}; +#[cfg(any(feature = "std", feature = "collections"))] +pub use de::content::{ + Content, + TaggedContentVisitor, + InternallyTaggedUnitVisitor, + UntaggedUnitVisitor, +}; + /// If the missing field is of type `Option` then treat is as `None`, /// otherwise it is an error. pub fn missing_field(field: &'static str) -> Result diff --git a/serde/src/de/value.rs b/serde/src/de/value.rs index 32b69d5f..85d15267 100644 --- a/serde/src/de/value.rs +++ b/serde/src/de/value.rs @@ -428,7 +428,9 @@ impl SeqDeserializer } } - fn end(&mut self) -> Result<(), E> { + /// Check for remaining elements after passing a `SeqDeserializer` to + /// `Visitor::visit_seq`. + pub fn end(mut self) -> Result<(), E> { let mut remaining = 0; while self.iter.next().is_some() { remaining += 1; @@ -610,17 +612,9 @@ impl MapDeserializer } } - fn next_pair(&mut self) -> Option<(::First, ::Second)> { - match self.iter.next() { - Some(kv) => { - self.count += 1; - Some(private::Pair::split(kv)) - } - None => None, - } - } - - fn end(&mut self) -> Result<(), E> { + /// Check for remaining elements after passing a `MapDeserializer` to + /// `Visitor::visit_map`. + pub fn end(mut self) -> Result<(), E> { let mut remaining = 0; while self.iter.next().is_some() { remaining += 1; @@ -633,6 +627,16 @@ impl MapDeserializer Err(de::Error::invalid_length(self.count + remaining, &ExpectedInMap(self.count))) } } + + fn next_pair(&mut self) -> Option<(::First, ::Second)> { + match self.iter.next() { + Some(kv) => { + self.count += 1; + Some(private::Pair::split(kv)) + } + None => None, + } + } } impl de::Deserializer for MapDeserializer diff --git a/serde/src/ser/mod.rs b/serde/src/ser/mod.rs index a18346ce..42ccd169 100644 --- a/serde/src/ser/mod.rs +++ b/serde/src/ser/mod.rs @@ -106,6 +106,10 @@ use core::fmt::Display; mod impls; mod impossible; +// Helpers used by generated code. Not public API. +#[doc(hidden)] +pub mod private; + pub use self::impossible::Impossible; /////////////////////////////////////////////////////////////////////////////// diff --git a/serde/src/ser/private.rs b/serde/src/ser/private.rs new file mode 100644 index 00000000..9d17ec78 --- /dev/null +++ b/serde/src/ser/private.rs @@ -0,0 +1,235 @@ +use core::fmt::{self, Display}; + +use ser::{self, Serialize, Serializer, SerializeMap, SerializeStruct}; + +/// Not public API. +pub fn serialize_tagged_newtype( + serializer: S, + type_ident: &'static str, + variant_ident: &'static str, + tag: &'static str, + variant_name: &'static str, + value: T, +) -> Result + where S: Serializer, + T: Serialize +{ + value.serialize(TaggedSerializer { + type_ident: type_ident, + variant_ident: variant_ident, + tag: tag, + variant_name: variant_name, + delegate: serializer, + }) +} + +struct TaggedSerializer { + type_ident: &'static str, + variant_ident: &'static str, + tag: &'static str, + variant_name: &'static str, + delegate: S, +} + +enum Unsupported { + Boolean, + Integer, + Float, + Char, + String, + ByteArray, + Optional, + Unit, + UnitStruct, + Sequence, + Tuple, + TupleStruct, + Enum, +} + +impl Display for Unsupported { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + match *self { + Unsupported::Boolean => formatter.write_str("a boolean"), + Unsupported::Integer => formatter.write_str("an integer"), + Unsupported::Float => formatter.write_str("a float"), + Unsupported::Char => formatter.write_str("a char"), + Unsupported::String => formatter.write_str("a string"), + Unsupported::ByteArray => formatter.write_str("a byte array"), + Unsupported::Optional => formatter.write_str("an optional"), + Unsupported::Unit => formatter.write_str("unit"), + Unsupported::UnitStruct => formatter.write_str("a unit struct"), + Unsupported::Sequence => formatter.write_str("a sequence"), + Unsupported::Tuple => formatter.write_str("a tuple"), + Unsupported::TupleStruct => formatter.write_str("a tuple struct"), + Unsupported::Enum => formatter.write_str("an enum"), + } + } +} + +struct Error { + type_ident: &'static str, + variant_ident: &'static str, + ty: Unsupported, +} + +impl Display for Error { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, + "cannot serialize tagged newtype variant {}::{} containing {}", + self.type_ident, self.variant_ident, self.ty) + } +} + +impl TaggedSerializer + where S: Serializer +{ + fn bad_type(self, what: Unsupported) -> S::Error { + ser::Error::custom(Error { + type_ident: self.type_ident, + variant_ident: self.variant_ident, + ty: what, + }) + } +} + +impl Serializer for TaggedSerializer + where S: Serializer +{ + type Ok = S::Ok; + type Error = S::Error; + + type SerializeSeq = S::SerializeSeq; + type SerializeTuple = S::SerializeTuple; + type SerializeTupleStruct = S::SerializeTupleStruct; + type SerializeTupleVariant = S::SerializeTupleVariant; + type SerializeMap = S::SerializeMap; + type SerializeStruct = S::SerializeStruct; + type SerializeStructVariant = S::SerializeStructVariant; + + fn serialize_bool(self, _: bool) -> Result { + Err(self.bad_type(Unsupported::Boolean)) + } + + fn serialize_i8(self, _: i8) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_i16(self, _: i16) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_i32(self, _: i32) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_i64(self, _: i64) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_u8(self, _: u8) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_u16(self, _: u16) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_u32(self, _: u32) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_u64(self, _: u64) -> Result { + Err(self.bad_type(Unsupported::Integer)) + } + + fn serialize_f32(self, _: f32) -> Result { + Err(self.bad_type(Unsupported::Float)) + } + + fn serialize_f64(self, _: f64) -> Result { + Err(self.bad_type(Unsupported::Float)) + } + + fn serialize_char(self, _: char) -> Result { + Err(self.bad_type(Unsupported::Char)) + } + + fn serialize_str(self, _: &str) -> Result { + Err(self.bad_type(Unsupported::String)) + } + + fn serialize_bytes(self, _: &[u8]) -> Result { + Err(self.bad_type(Unsupported::ByteArray)) + } + + fn serialize_none(self) -> Result { + Err(self.bad_type(Unsupported::Optional)) + } + + fn serialize_some(self, _: &T) -> Result + where T: Serialize + { + Err(self.bad_type(Unsupported::Optional)) + } + + fn serialize_unit(self) -> Result { + Err(self.bad_type(Unsupported::Unit)) + } + + fn serialize_unit_struct(self, _: &'static str) -> Result { + Err(self.bad_type(Unsupported::UnitStruct)) + } + + fn serialize_unit_variant(self, _: &'static str, _: usize, _: &'static str) -> Result { + Err(self.bad_type(Unsupported::Enum)) + } + + fn serialize_newtype_struct(self, _: &'static str, value: &T) -> Result + where T: Serialize + { + value.serialize(self) + } + + fn serialize_newtype_variant(self, _: &'static str, _: usize, _: &'static str, _: &T) -> Result + where T: Serialize + { + Err(self.bad_type(Unsupported::Enum)) + } + + fn serialize_seq(self, _: Option) -> Result { + Err(self.bad_type(Unsupported::Sequence)) + } + + fn serialize_seq_fixed_size(self, _: usize) -> Result { + Err(self.bad_type(Unsupported::Sequence)) + } + + fn serialize_tuple(self, _: usize) -> Result { + Err(self.bad_type(Unsupported::Tuple)) + } + + fn serialize_tuple_struct(self, _: &'static str, _: usize) -> Result { + Err(self.bad_type(Unsupported::TupleStruct)) + } + + fn serialize_tuple_variant(self, _: &'static str, _: usize, _: &'static str, _: usize) -> Result { + Err(self.bad_type(Unsupported::Enum)) + } + + fn serialize_map(self, len: Option) -> Result { + let mut map = try!(self.delegate.serialize_map(len.map(|len| len + 1))); + try!(map.serialize_entry(self.tag, self.variant_name)); + Ok(map) + } + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + let mut state = try!(self.delegate.serialize_struct(name, len + 1)); + try!(state.serialize_field(self.tag, self.variant_name)); + Ok(state) + } + + fn serialize_struct_variant(self, _: &'static str, _: usize, _: &'static str, _: usize) -> Result { + Err(self.bad_type(Unsupported::Enum)) + } +} diff --git a/serde_codegen_internals/src/attr.rs b/serde_codegen_internals/src/attr.rs index 575dad5a..97fa4650 100644 --- a/serde_codegen_internals/src/attr.rs +++ b/serde_codegen_internals/src/attr.rs @@ -92,6 +92,32 @@ pub struct Item { deny_unknown_fields: bool, ser_bound: Option>, de_bound: Option>, + tag: EnumTag, +} + +/// Styles of representing an enum. +#[derive(Debug)] +pub enum EnumTag { + /// The default. + /// + /// ```json + /// {"variant1": {"key1": "value1", "key2": "value2"}} + /// ``` + External, + + /// `#[serde(tag = "type")]` + /// + /// ```json + /// {"type": "variant1", "key1": "value1", "key2": "value2"} + /// ``` + Internal(String), + + /// `#[serde(untagged)]` + /// + /// ```json + /// {"key1": "value1", "key2": "value2"} + /// ``` + None, } impl Item { @@ -102,6 +128,8 @@ impl Item { let mut deny_unknown_fields = BoolAttr::none(cx, "deny_unknown_fields"); let mut ser_bound = Attr::none(cx, "bound"); let mut de_bound = Attr::none(cx, "bound"); + let mut untagged = BoolAttr::none(cx, "untagged"); + let mut internal_tag = Attr::none(cx, "tag"); for meta_items in item.attrs.iter().filter_map(get_serde_meta_items) { for meta_item in meta_items { @@ -143,6 +171,32 @@ impl Item { } } + // Parse `#[serde(untagged)]` + MetaItem(Word(ref name)) if name == "untagged" => { + match item.body { + syn::Body::Enum(_) => { + untagged.set_true(); + } + syn::Body::Struct(_) => { + cx.error("#[serde(untagged)] can only be used on enums") + } + } + } + + // Parse `#[serde(tag = "type")]` + MetaItem(NameValue(ref name, ref lit)) if name == "tag" => { + if let Ok(s) = get_string_from_lit(cx, name.as_ref(), name.as_ref(), lit) { + match item.body { + syn::Body::Enum(_) => { + internal_tag.set(s); + } + syn::Body::Struct(_) => { + cx.error("#[serde(tag = \"...\")] can only be used on enums") + } + } + } + } + MetaItem(ref meta_item) => { cx.error(format!("unknown serde container attribute `{}`", meta_item.name())); @@ -155,6 +209,32 @@ impl Item { } } + let tag = match (untagged.get(), internal_tag.get()) { + (false, None) => EnumTag::External, + (true, None) => EnumTag::None, + (false, Some(tag)) => { + // Check that there are no tuple variants. + if let syn::Body::Enum(ref variants) = item.body { + for variant in variants { + match variant.data { + syn::VariantData::Struct(_) | syn::VariantData::Unit => {} + syn::VariantData::Tuple(ref fields) => { + if fields.len() != 1 { + cx.error("#[serde(tag = \"...\")] cannot be used with tuple variants"); + break; + } + } + } + } + } + EnumTag::Internal(tag) + } + (true, Some(_)) => { + cx.error("enum cannot be both untagged and internally tagged"); + EnumTag::External // doesn't matter, will error + } + }; + Item { name: Name { serialize: ser_name.get().unwrap_or_else(|| item.ident.to_string()), @@ -163,6 +243,7 @@ impl Item { deny_unknown_fields: deny_unknown_fields.get(), ser_bound: ser_bound.get(), de_bound: de_bound.get(), + tag: tag, } } @@ -181,6 +262,10 @@ impl Item { pub fn de_bound(&self) -> Option<&[syn::WherePredicate]> { self.de_bound.as_ref().map(|vec| &vec[..]) } + + pub fn tag(&self) -> &EnumTag { + &self.tag + } } /// Represents variant attribute information diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 3d7b5d1f..b72d78ad 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -110,7 +110,8 @@ fn deserialize_body( impl_generics, ty, fields, - &item.attrs) + &item.attrs, + None) } Body::Struct(Style::Tuple, ref fields) | Body::Struct(Style::Newtype, ref fields) => { @@ -124,7 +125,8 @@ fn deserialize_body( impl_generics, ty, fields, - &item.attrs) + &item.attrs, + None) } Body::Struct(Style::Unit, _) => { deserialize_unit_struct( @@ -238,6 +240,7 @@ fn deserialize_tuple( ty: syn::Ty, fields: &[Field], item_attrs: &attr::Item, + deserializer: Option, ) -> Tokens { let where_clause = &impl_generics.where_clause; @@ -274,7 +277,9 @@ fn deserialize_tuple( false, ); - let dispatch = if is_enum { + let dispatch = if let Some(deserializer) = deserializer { + quote!(_serde::Deserializer::deserialize(#deserializer, #visitor_expr)) + } else if is_enum { quote!(_serde::de::VariantVisitor::visit_tuple(visitor, #nfields, #visitor_expr)) } else if nfields == 1 { let type_name = item_attrs.name().deserialize_name(); @@ -424,7 +429,11 @@ fn deserialize_struct( ty: syn::Ty, fields: &[Field], item_attrs: &attr::Item, + deserializer: Option, ) -> Tokens { + let is_enum = variant_ident.is_some(); + let is_untagged = deserializer.is_some(); + let where_clause = &impl_generics.where_clause; let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor(impl_generics); @@ -454,8 +463,11 @@ fn deserialize_struct( item_attrs, ); - let is_enum = variant_ident.is_some(); - let dispatch = if is_enum { + let dispatch = if let Some(deserializer) = deserializer { + quote! { + _serde::Deserializer::deserialize(#deserializer, #visitor_expr) + } + } else if is_enum { quote! { _serde::de::VariantVisitor::visit_struct(visitor, FIELDS, #visitor_expr) } @@ -473,6 +485,20 @@ fn deserialize_struct( quote!(mut visitor) }; + let visit_seq = if is_untagged { + // untagged struct variants do not get a visit_seq method + None + } else { + Some(quote! { + #[inline] + fn visit_seq<__V>(self, #visitor_var: __V) -> _serde::export::Result<#ty, __V::Error> + where __V: _serde::de::SeqVisitor + { + #visit_seq + } + }) + }; + quote!({ #field_visitor @@ -485,12 +511,7 @@ fn deserialize_struct( _serde::export::fmt::Formatter::write_str(formatter, #expecting) } - #[inline] - fn visit_seq<__V>(self, #visitor_var: __V) -> _serde::export::Result<#ty, __V::Error> - where __V: _serde::de::SeqVisitor - { - #visit_seq - } + #visit_seq #[inline] fn visit_map<__V>(self, mut visitor: __V) -> _serde::export::Result<#ty, __V::Error> @@ -512,6 +533,45 @@ fn deserialize_item_enum( ty: syn::Ty, variants: &[Variant], item_attrs: &attr::Item +) -> Tokens { + match *item_attrs.tag() { + attr::EnumTag::External => { + deserialize_externally_tagged_enum( + type_ident, + impl_generics, + ty, + variants, + item_attrs, + ) + } + attr::EnumTag::Internal(ref tag) => { + deserialize_internally_tagged_enum( + type_ident, + impl_generics, + ty, + variants, + item_attrs, + tag, + ) + } + attr::EnumTag::None => { + deserialize_untagged_enum( + type_ident, + impl_generics, + ty, + variants, + item_attrs, + ) + } + } +} + +fn deserialize_externally_tagged_enum( + type_ident: &syn::Ident, + impl_generics: &syn::Generics, + ty: syn::Ty, + variants: &[Variant], + item_attrs: &attr::Item, ) -> Tokens { let where_clause = &impl_generics.where_clause; @@ -545,7 +605,7 @@ fn deserialize_item_enum( .map(|(i, variant)| { let variant_name = field_i(i); - let block = deserialize_variant( + let block = deserialize_externally_tagged_variant( type_ident, impl_generics, ty.clone(), @@ -604,7 +664,111 @@ fn deserialize_item_enum( }) } -fn deserialize_variant( +fn deserialize_internally_tagged_enum( + type_ident: &syn::Ident, + impl_generics: &syn::Generics, + ty: syn::Ty, + variants: &[Variant], + item_attrs: &attr::Item, + tag: &str, +) -> Tokens { + let variant_names_idents: Vec<_> = variants.iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| (variant.attrs.name().deserialize_name(), field_i(i))) + .collect(); + + let variants_stmt = { + let variant_names = variant_names_idents.iter().map(|&(ref name, _)| name); + quote! { + const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ]; + } + }; + + let variant_visitor = deserialize_field_visitor( + variant_names_idents, + item_attrs, + true, + ); + + // Match arms to extract a variant from a string + let variant_arms = variants.iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| { + let variant_name = field_i(i); + + let block = deserialize_internally_tagged_variant( + type_ident, + impl_generics, + ty.clone(), + variant, + item_attrs, + quote!(_tagged.content), + ); + + quote! { + __Field::#variant_name => #block + } + }); + + quote!({ + #variant_visitor + + #variants_stmt + + let _tagged = try!(_serde::Deserializer::deserialize( + deserializer, + _serde::de::private::TaggedContentVisitor::<__Field, __D::Error>::new(#tag))); + + match _tagged.tag { + #(#variant_arms)* + } + }) +} + +fn deserialize_untagged_enum( + type_ident: &syn::Ident, + impl_generics: &syn::Generics, + ty: syn::Ty, + variants: &[Variant], + item_attrs: &attr::Item, +) -> Tokens { + let attempts = variants.iter() + .filter(|variant| !variant.attrs.skip_deserializing()) + .map(|variant| { + deserialize_untagged_variant( + type_ident, + impl_generics, + ty.clone(), + variant, + item_attrs, + quote!(&_content), + ) + }); + + // TODO this message could be better by saving the errors from the failed + // attempts. The heuristic used by TOML was to count the number of fields + // processed before an error, and use the error that happened after the + // largest number of fields. I'm not sure I like that. Maybe it would be + // better to save all the errors and combine them into one message that + // explains why none of the variants matched. + let fallthrough_msg = format!("data did not match any variant of untagged enum {}", type_ident); + + quote!({ + let _content = try!(<_serde::de::private::Content<__D::Error> as _serde::Deserialize>::deserialize(deserializer)); + + #( + if let _serde::export::Ok(ok) = #attempts { + return _serde::export::Ok(ok); + } + )* + + _serde::export::Err(_serde::de::Error::custom(#fallthrough_msg)) + }) +} + +fn deserialize_externally_tagged_variant( type_ident: &syn::Ident, generics: &syn::Generics, ty: syn::Ty, @@ -621,7 +785,7 @@ fn deserialize_variant( }) } Style::Newtype => { - deserialize_newtype_variant( + deserialize_externally_tagged_newtype_variant( type_ident, variant_ident, generics, @@ -636,6 +800,7 @@ fn deserialize_variant( ty, &variant.fields, item_attrs, + None, ) } Style::Struct => { @@ -646,22 +811,115 @@ fn deserialize_variant( ty, &variant.fields, item_attrs, + None, ) } } } -fn deserialize_newtype_variant( +fn deserialize_internally_tagged_variant( + type_ident: &syn::Ident, + generics: &syn::Generics, + ty: syn::Ty, + variant: &Variant, + item_attrs: &attr::Item, + deserializer: Tokens, +) -> Tokens { + let variant_ident = &variant.ident; + + match variant.style { + Style::Unit => { + let type_name = type_ident.as_ref(); + let variant_name = variant.ident.as_ref(); + quote!({ + try!(_serde::Deserializer::deserialize(#deserializer, _serde::de::private::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))); + _serde::export::Ok(#type_ident::#variant_ident) + }) + } + Style::Newtype | Style::Struct => { + deserialize_untagged_variant( + type_ident, + generics, + ty, + variant, + item_attrs, + deserializer, + ) + } + Style::Tuple => unreachable!("checked in serde_codegen_internals"), + } +} + +fn deserialize_untagged_variant( + type_ident: &syn::Ident, + generics: &syn::Generics, + ty: syn::Ty, + variant: &Variant, + item_attrs: &attr::Item, + deserializer: Tokens, +) -> Tokens { + let variant_ident = &variant.ident; + + match variant.style { + Style::Unit => { + let type_name = type_ident.as_ref(); + let variant_name = variant.ident.as_ref(); + quote! { + _serde::export::Result::map( + _serde::Deserializer::deserialize( + #deserializer, + _serde::de::private::UntaggedUnitVisitor::new(#type_name, #variant_name) + ), + |()| #type_ident::#variant_ident) + } + } + Style::Newtype => { + deserialize_untagged_newtype_variant( + type_ident, + variant_ident, + generics, + &variant.fields[0], + deserializer, + ) + } + Style::Tuple => { + deserialize_tuple( + type_ident, + Some(variant_ident), + generics, + ty, + &variant.fields, + item_attrs, + Some(deserializer), + ) + } + Style::Struct => { + deserialize_struct( + type_ident, + Some(variant_ident), + generics, + ty, + &variant.fields, + item_attrs, + Some(deserializer), + ) + } + } +} + +fn deserialize_externally_tagged_newtype_variant( type_ident: &syn::Ident, variant_ident: &syn::Ident, impl_generics: &syn::Generics, field: &Field, ) -> Tokens { - let visit = match field.attrs.deserialize_with() { + match field.attrs.deserialize_with() { None => { let field_ty = &field.ty; quote! { - try!(_serde::de::VariantVisitor::visit_newtype::<#field_ty>(visitor)) + _serde::export::Result::map( + _serde::de::VariantVisitor::visit_newtype::<#field_ty>(visitor), + #type_ident::#variant_ident), } } Some(path) => { @@ -670,12 +928,41 @@ fn deserialize_newtype_variant( quote!({ #wrapper #wrapper_impl - try!(_serde::de::VariantVisitor::visit_newtype::<#wrapper_ty>(visitor)).value + _serde::export::Result::map( + _serde::de::VariantVisitor::visit_newtype::<#wrapper_ty>(visitor), + |_wrapper| #type_ident::#variant_ident(_wrapper.value)) + }) + } + } +} + +fn deserialize_untagged_newtype_variant( + type_ident: &syn::Ident, + variant_ident: &syn::Ident, + impl_generics: &syn::Generics, + field: &Field, + deserializer: Tokens, +) -> Tokens { + match field.attrs.deserialize_with() { + None => { + let field_ty = &field.ty; + quote!({ + _serde::export::Result::map( + <#field_ty as _serde::Deserialize>::deserialize(#deserializer), + #type_ident::#variant_ident) + }) + } + Some(path) => { + let (wrapper, wrapper_impl, wrapper_ty) = wrap_deserialize_with( + type_ident, impl_generics, field.ty, path); + quote!({ + #wrapper + #wrapper_impl + _serde::export::Result::map( + <#wrapper_ty as _serde::Deserialize>::deserialize(#deserializer), + |_wrapper| #type_ident::#variant_ident(_wrapper.value)) }) } - }; - quote! { - _serde::export::Ok(#type_ident::#variant_ident(#visit)), } } diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 6478de41..52c86c7e 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -251,14 +251,11 @@ fn serialize_variant( variant_index: usize, item_attrs: &attr::Item, ) -> Tokens { - let type_name = item_attrs.name().serialize_name(); - let variant_ident = variant.ident.clone(); - let variant_name = variant.attrs.name().serialize_name(); if variant.attrs.skip_serializing() { let skipped_msg = format!("the enum variant {}::{} cannot be serialized", - type_ident, variant_ident); + type_ident, variant_ident); let skipped_err = quote! { _serde::export::Err(_serde::ser::Error::custom(#skipped_msg)) }; @@ -271,140 +268,351 @@ fn serialize_variant( #type_ident::#variant_ident #fields_pat => #skipped_err, } } else { // variant wasn't skipped - match variant.style { + let case = match variant.style { Style::Unit => { quote! { - #type_ident::#variant_ident => - _serde::Serializer::serialize_unit_variant( - _serializer, - #type_name, - #variant_index, - #variant_name, - ), + #type_ident::#variant_ident } - }, + } Style::Newtype => { - let block = serialize_newtype_variant( - type_name, - variant_index, - variant_name, - ty, - generics, - &variant.fields[0], - ); - quote! { - #type_ident::#variant_ident(ref __simple_value) => #block, + #type_ident::#variant_ident(ref __simple_value) } - }, + } Style::Tuple => { let field_names = (0 .. variant.fields.len()) .map(|i| Ident::new(format!("__field{}", i))); - - let block = serialize_tuple_variant( - type_name, - variant_index, - variant_name, - generics, - ty, - &variant.fields, - ); - quote! { - #type_ident::#variant_ident(#(ref #field_names),*) => { #block } + #type_ident::#variant_ident(#(ref #field_names),*) } } Style::Struct => { let fields = variant.fields.iter() .map(|f| f.ident.clone().expect("struct variant has unnamed fields")); + quote! { + #type_ident::#variant_ident { #(ref #fields),* } + } + } + }; - let block = serialize_struct_variant( - variant_index, - variant_name, + let body = match *item_attrs.tag() { + attr::EnumTag::External => { + serialize_externally_tagged_variant( generics, ty, - &variant.fields, + variant, + variant_index, item_attrs, - ); + ) + } + attr::EnumTag::Internal(ref tag) => { + serialize_internally_tagged_variant( + type_ident.as_ref(), + variant_ident.as_ref(), + generics, + ty, + variant, + item_attrs, + tag, + ) + } + attr::EnumTag::None => { + serialize_untagged_variant( + generics, + ty, + variant, + item_attrs, + ) + } + }; - quote! { - #type_ident::#variant_ident { #(ref #fields),* } => { #block } - } + quote! { + #case => #body + } + } +} + +fn serialize_externally_tagged_variant( + generics: &syn::Generics, + ty: syn::Ty, + variant: &Variant, + variant_index: usize, + item_attrs: &attr::Item, +) -> Tokens { + let type_name = item_attrs.name().serialize_name(); + let variant_name = variant.attrs.name().serialize_name(); + + match variant.style { + Style::Unit => { + quote! { + _serde::Serializer::serialize_unit_variant( + _serializer, + #type_name, + #variant_index, + #variant_name, + ), + } + } + Style::Newtype => { + let field = &variant.fields[0]; + let mut field_expr = quote!(__simple_value); + if let Some(path) = field.attrs.serialize_with() { + field_expr = wrap_serialize_with( + &ty, generics, field.ty, path, field_expr); + } + + quote! { + _serde::Serializer::serialize_newtype_variant( + _serializer, + #type_name, + #variant_index, + #variant_name, + #field_expr, + ), + } + } + Style::Tuple => { + let block = serialize_tuple_variant( + TupleVariant::ExternallyTagged { + type_name: type_name, + variant_index: variant_index, + variant_name: variant_name, + }, + generics, + ty, + &variant.fields, + ); + + quote! { + { #block } + } + } + Style::Struct => { + let block = serialize_struct_variant( + StructVariant::ExternallyTagged { + variant_index: variant_index, + variant_name: variant_name, + }, + generics, + ty, + &variant.fields, + item_attrs, + ); + + quote! { + { #block } } } } } -fn serialize_newtype_variant( - type_name: String, - variant_index: usize, - variant_name: String, - item_ty: syn::Ty, +fn serialize_internally_tagged_variant( + type_ident: &str, + variant_ident: &str, generics: &syn::Generics, - field: &Field, + ty: syn::Ty, + variant: &Variant, + item_attrs: &attr::Item, + tag: &str, ) -> Tokens { - let mut field_expr = quote!(__simple_value); - if let Some(path) = field.attrs.serialize_with() { - field_expr = wrap_serialize_with( - &item_ty, generics, field.ty, path, field_expr); - } + let type_name = item_attrs.name().serialize_name(); + let variant_name = variant.attrs.name().serialize_name(); - quote! { - _serde::Serializer::serialize_newtype_variant( - _serializer, - #type_name, - #variant_index, - #variant_name, - #field_expr, - ) + match variant.style { + Style::Unit => { + quote!({ + let mut __struct = try!(_serde::Serializer::serialize_struct( + _serializer, #type_name, 1)); + try!(_serde::ser::SerializeStruct::serialize_field( + &mut __struct, #tag, #variant_name)); + _serde::ser::SerializeStruct::end(__struct) + }) + } + Style::Newtype => { + let field = &variant.fields[0]; + let mut field_expr = quote!(__simple_value); + if let Some(path) = field.attrs.serialize_with() { + field_expr = wrap_serialize_with( + &ty, generics, field.ty, path, field_expr); + } + + quote! { + _serde::ser::private::serialize_tagged_newtype( + _serializer, + #type_ident, + #variant_ident, + #tag, + #variant_name, + #field_expr, + ), + } + } + Style::Struct => { + let block = serialize_struct_variant( + StructVariant::InternallyTagged { + tag: tag, + variant_name: variant_name, + }, + generics, + ty, + &variant.fields, + item_attrs, + ); + + quote! { + { #block } + } + } + Style::Tuple => unreachable!("checked in serde_codegen_internals"), } } +fn serialize_untagged_variant( + generics: &syn::Generics, + ty: syn::Ty, + variant: &Variant, + item_attrs: &attr::Item, +) -> Tokens { + match variant.style { + Style::Unit => { + quote! { + _serde::Serializer::serialize_unit(_serializer), + } + } + Style::Newtype => { + let field = &variant.fields[0]; + let mut field_expr = quote!(__simple_value); + if let Some(path) = field.attrs.serialize_with() { + field_expr = wrap_serialize_with( + &ty, generics, field.ty, path, field_expr); + } + + quote! { + _serde::Serialize::serialize(#field_expr, _serializer), + } + } + Style::Tuple => { + let block = serialize_tuple_variant( + TupleVariant::Untagged, + generics, + ty, + &variant.fields, + ); + + quote! { + { #block } + } + } + Style::Struct => { + let block = serialize_struct_variant( + StructVariant::Untagged, + generics, + ty, + &variant.fields, + item_attrs, + ); + + quote! { + { #block } + } + } + } +} + +enum TupleVariant { + ExternallyTagged { + type_name: String, + variant_index: usize, + variant_name: String, + }, + Untagged, +} + fn serialize_tuple_variant( - type_name: String, - variant_index: usize, - variant_name: String, + context: TupleVariant, generics: &syn::Generics, structure_ty: syn::Ty, fields: &[Field], ) -> Tokens { + let method = match context { + TupleVariant::ExternallyTagged{..} => { + quote!(_serde::ser::SerializeTupleVariant::serialize_field) + } + TupleVariant::Untagged => { + quote!(_serde::ser::SerializeTuple::serialize_element) + } + }; + let serialize_stmts = serialize_tuple_struct_visitor( structure_ty, fields, generics, true, - quote!(_serde::ser::SerializeTupleVariant::serialize_field), + method, ); let len = serialize_stmts.len(); let let_mut = mut_if(len > 0); - quote! { - let #let_mut __serde_state = try!(_serde::Serializer::serialize_tuple_variant( - _serializer, - #type_name, - #variant_index, - #variant_name, - #len)); - #(#serialize_stmts)* - _serde::ser::SerializeTupleVariant::end(__serde_state) + match context { + TupleVariant::ExternallyTagged { type_name, variant_index, variant_name } => { + quote! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_tuple_variant( + _serializer, + #type_name, + #variant_index, + #variant_name, + #len)); + #(#serialize_stmts)* + _serde::ser::SerializeTupleVariant::end(__serde_state) + } + } + TupleVariant::Untagged => { + quote! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_tuple( + _serializer, + #len)); + #(#serialize_stmts)* + _serde::ser::SerializeTuple::end(__serde_state) + } + } } } -fn serialize_struct_variant( - variant_index: usize, - variant_name: String, +enum StructVariant<'a> { + ExternallyTagged { + variant_index: usize, + variant_name: String, + }, + InternallyTagged { + tag: &'a str, + variant_name: String, + }, + Untagged, +} + +fn serialize_struct_variant<'a>( + context: StructVariant<'a>, generics: &syn::Generics, ty: syn::Ty, fields: &[Field], item_attrs: &attr::Item, ) -> Tokens { + let method = match context { + StructVariant::ExternallyTagged{..} => { + quote!(_serde::ser::SerializeStructVariant::serialize_field) + } + StructVariant::InternallyTagged{..} | StructVariant::Untagged => { + quote!(_serde::ser::SerializeStruct::serialize_field) + } + }; + let serialize_fields = serialize_struct_visitor( ty.clone(), fields, generics, true, - quote!(_serde::ser::SerializeStructVariant::serialize_field), + method, ); let item_name = item_attrs.name().serialize_name(); @@ -426,16 +634,47 @@ fn serialize_struct_variant( }) .fold(quote!(0), |sum, expr| quote!(#sum + #expr)); - quote! { - let #let_mut __serde_state = try!(_serde::Serializer::serialize_struct_variant( - _serializer, - #item_name, - #variant_index, - #variant_name, - #len, - )); - #(#serialize_fields)* - _serde::ser::SerializeStructVariant::end(__serde_state) + match context { + StructVariant::ExternallyTagged { variant_index, variant_name } => { + quote! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_struct_variant( + _serializer, + #item_name, + #variant_index, + #variant_name, + #len, + )); + #(#serialize_fields)* + _serde::ser::SerializeStructVariant::end(__serde_state) + } + } + StructVariant::InternallyTagged { tag, variant_name } => { + quote! { + let mut __serde_state = try!(_serde::Serializer::serialize_struct( + _serializer, + #item_name, + #len + 1, + )); + try!(_serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #tag, + #variant_name, + )); + #(#serialize_fields)* + _serde::ser::SerializeStruct::end(__serde_state) + } + } + StructVariant::Untagged => { + quote! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_struct( + _serializer, + #item_name, + #len, + )); + #(#serialize_fields)* + _serde::ser::SerializeStruct::end(__serde_state) + } + } } } diff --git a/serde_test/src/de.rs b/serde_test/src/de.rs index 8bdfdcee..138ed0cd 100644 --- a/serde_test/src/de.rs +++ b/serde_test/src/de.rs @@ -47,91 +47,29 @@ impl Deserializer } } - fn visit_seq(&mut self, len: Option, visitor: V) -> Result + fn visit_seq(&mut self, len: Option, sep: Token<'static>, end: Token<'static>, visitor: V) -> Result where V: Visitor, { let value = try!(visitor.visit_seq(DeserializerSeqVisitor { de: self, len: len, + sep: sep, + end: end.clone(), })); - try!(self.expect_token(Token::SeqEnd)); + try!(self.expect_token(end)); Ok(value) } - fn visit_array(&mut self, len: usize, visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_seq(DeserializerArrayVisitor { - de: self, - len: len, - })); - try!(self.expect_token(Token::SeqEnd)); - Ok(value) - } - - fn visit_tuple(&mut self, len: usize, visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_seq(DeserializerTupleVisitor { - de: self, - len: len, - })); - try!(self.expect_token(Token::TupleEnd)); - Ok(value) - } - - fn visit_tuple_struct(&mut self, len: usize, visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_seq(DeserializerTupleStructVisitor { - de: self, - len: len, - })); - try!(self.expect_token(Token::TupleStructEnd)); - Ok(value) - } - - fn visit_variant_seq(&mut self, len: Option, visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_seq(DeserializerVariantSeqVisitor { - de: self, - len: len, - })); - try!(self.expect_token(Token::EnumSeqEnd)); - Ok(value) - } - - fn visit_map(&mut self, len: Option, visitor: V) -> Result + fn visit_map(&mut self, len: Option, sep: Token<'static>, end: Token<'static>, visitor: V) -> Result where V: Visitor, { let value = try!(visitor.visit_map(DeserializerMapVisitor { de: self, len: len, + sep: sep, + end: end.clone(), })); - try!(self.expect_token(Token::MapEnd)); - Ok(value) - } - - fn visit_struct(&mut self, fields: &'static [&'static str], visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_map(DeserializerStructVisitor { - de: self, - len: fields.len(), - })); - try!(self.expect_token(Token::StructEnd)); - Ok(value) - } - - fn visit_variant_map(&mut self, len: Option, visitor: V) -> Result - where V: Visitor, - { - let value = try!(visitor.visit_map(DeserializerVariantMapVisitor { - de: self, - len: len, - })); - try!(self.expect_token(Token::EnumMapEnd)); + try!(self.expect_token(end)); Ok(value) } } @@ -141,89 +79,9 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer { type Error = Error; - fn deserialize_seq<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_struct_field<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_map<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_unit<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_bytes<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_byte_buf<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_ignored_any<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_string<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_str<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_char<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_i64<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_i32<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_i16<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_i8<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_u64<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_u32<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_u16<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_u8<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_f32<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_f64<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) - } - fn deserialize_bool<__V>(self, visitor: __V) -> Result<__V::Value, Self::Error> - where __V: de::Visitor { - self.deserialize(visitor) + forward_to_deserialize! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit + seq bytes byte_buf map struct_field ignored_any } fn deserialize(self, visitor: V) -> Result @@ -251,16 +109,22 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer Some(Token::Unit) => visitor.visit_unit(), Some(Token::UnitStruct(_name)) => visitor.visit_unit(), Some(Token::SeqStart(len)) => { - self.visit_seq(len, visitor) + self.visit_seq(len, Token::SeqSep, Token::SeqEnd, visitor) } - Some(Token::SeqArrayStart(len))| Some(Token::TupleStructStart(_, len)) => { - self.visit_seq(Some(len), visitor) + Some(Token::SeqArrayStart(len)) => { + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) + } + Some(Token::TupleStart(len)) => { + self.visit_seq(Some(len), Token::TupleSep, Token::TupleEnd, visitor) + } + Some(Token::TupleStructStart(_, len)) => { + self.visit_seq(Some(len), Token::TupleStructSep, Token::TupleStructEnd, visitor) } Some(Token::MapStart(len)) => { - self.visit_map(len, visitor) + self.visit_map(len, Token::MapSep, Token::MapEnd, visitor) } Some(Token::StructStart(_, len)) => { - self.visit_map(Some(len), visitor) + self.visit_map(Some(len), Token::StructSep, Token::StructEnd, visitor) } Some(token) => Err(Error::UnexpectedToken(token)), None => Err(Error::EndOfTokens), @@ -360,7 +224,7 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer match self.tokens.peek() { Some(&Token::SeqArrayStart(_)) => { self.tokens.next(); - self.visit_array(len, visitor) + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } Some(_) => self.deserialize(visitor), None => Err(Error::EndOfTokens), @@ -379,19 +243,19 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer } Some(&Token::SeqStart(_)) => { self.tokens.next(); - self.visit_seq(Some(len), visitor) + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } Some(&Token::SeqArrayStart(_)) => { self.tokens.next(); - self.visit_array(len, visitor) + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } Some(&Token::TupleStart(_)) => { self.tokens.next(); - self.visit_tuple(len, visitor) + self.visit_seq(Some(len), Token::TupleSep, Token::TupleEnd, visitor) } Some(&Token::TupleStructStart(_, _)) => { self.tokens.next(); - self.visit_tuple_struct(len, visitor) + self.visit_seq(Some(len), Token::TupleStructSep, Token::TupleStructEnd, visitor) } Some(_) => self.deserialize(visitor), None => Err(Error::EndOfTokens), @@ -419,20 +283,20 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer } Some(&Token::SeqStart(_)) => { self.tokens.next(); - self.visit_seq(Some(len), visitor) + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } Some(&Token::SeqArrayStart(_)) => { self.tokens.next(); - self.visit_array(len, visitor) + self.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } Some(&Token::TupleStart(_)) => { self.tokens.next(); - self.visit_tuple(len, visitor) + self.visit_seq(Some(len), Token::TupleSep, Token::TupleEnd, visitor) } Some(&Token::TupleStructStart(n, _)) => { self.tokens.next(); if name == n { - self.visit_tuple_struct(len, visitor) + self.visit_seq(Some(len), Token::TupleStructSep, Token::TupleStructEnd, visitor) } else { Err(Error::InvalidName(n)) } @@ -452,14 +316,14 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer Some(&Token::StructStart(n, _)) => { self.tokens.next(); if name == n { - self.visit_struct(fields, visitor) + self.visit_map(Some(fields.len()), Token::StructSep, Token::StructEnd, visitor) } else { Err(Error::InvalidName(n)) } } Some(&Token::MapStart(_)) => { self.tokens.next(); - self.visit_map(Some(fields.len()), visitor) + self.visit_map(Some(fields.len()), Token::MapSep, Token::MapEnd, visitor) } Some(_) => self.deserialize(visitor), None => Err(Error::EndOfTokens), @@ -472,6 +336,8 @@ impl<'a, I> de::Deserializer for &'a mut Deserializer struct DeserializerSeqVisitor<'a, I: 'a> where I: Iterator> { de: &'a mut Deserializer, len: Option, + sep: Token<'static>, + end: Token<'static>, } impl<'a, I> SeqVisitor for DeserializerSeqVisitor<'a, I> @@ -482,158 +348,15 @@ impl<'a, I> SeqVisitor for DeserializerSeqVisitor<'a, I> fn visit_seed(&mut self, seed: T) -> Result, Error> where T: DeserializeSeed, { - match self.de.tokens.peek() { - Some(&Token::SeqSep) => { - self.de.tokens.next(); - self.len = self.len.map(|len| len - 1); - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::SeqEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), + if self.de.tokens.peek() == Some(&self.end) { + return Ok(None); } - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.len.unwrap_or(0); - (len, self.len) - } -} - -////////////////////////////////////////////////////////////////////////// - -struct DeserializerArrayVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: usize, -} - -impl<'a, I> SeqVisitor for DeserializerArrayVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_seed(&mut self, seed: T) -> Result, Error> - where T: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::SeqSep) => { - self.de.tokens.next(); - self.len -= 1; + match self.de.tokens.next() { + Some(ref token) if *token == self.sep => { + self.len = self.len.map(|len| len.saturating_sub(1)); seed.deserialize(&mut *self.de).map(Some) } - Some(&Token::SeqEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) - } -} - -////////////////////////////////////////////////////////////////////////// - -struct DeserializerTupleVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: usize, -} - -impl<'a, I> SeqVisitor for DeserializerTupleVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_seed(&mut self, seed: T) -> Result, Error> - where T: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::TupleSep) => { - self.de.tokens.next(); - self.len -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::TupleEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) - } -} - -////////////////////////////////////////////////////////////////////////// - -struct DeserializerTupleStructVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: usize, -} - -impl<'a, I> SeqVisitor for DeserializerTupleStructVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_seed(&mut self, seed: T) -> Result, Error> - where T: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::TupleStructSep) => { - self.de.tokens.next(); - self.len -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::TupleStructEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) - } -} - -////////////////////////////////////////////////////////////////////////// - -struct DeserializerVariantSeqVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: Option, -} - -impl<'a, I> SeqVisitor for DeserializerVariantSeqVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_seed(&mut self, seed: T) -> Result, Error> - where T: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::EnumSeqSep) => { - self.de.tokens.next(); - self.len = self.len.map(|len| len - 1); - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::EnumSeqEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } + Some(other) => Err(Error::UnexpectedToken(other)), None => Err(Error::EndOfTokens), } } @@ -649,6 +372,8 @@ impl<'a, I> SeqVisitor for DeserializerVariantSeqVisitor<'a, I> struct DeserializerMapVisitor<'a, I: 'a> where I: Iterator> { de: &'a mut Deserializer, len: Option, + sep: Token<'static>, + end: Token<'static>, } impl<'a, I> MapVisitor for DeserializerMapVisitor<'a, I> @@ -659,17 +384,15 @@ impl<'a, I> MapVisitor for DeserializerMapVisitor<'a, I> fn visit_key_seed(&mut self, seed: K) -> Result, Error> where K: DeserializeSeed, { - match self.de.tokens.peek() { - Some(&Token::MapSep) => { - self.de.tokens.next(); - self.len = self.len.map(|len| if len > 0 { len - 1} else { 0 }); + if self.de.tokens.peek() == Some(&self.end) { + return Ok(None); + } + match self.de.tokens.next() { + Some(ref token) if *token == self.sep => { + self.len = self.len.map(|len| len.saturating_sub(1)); seed.deserialize(&mut *self.de).map(Some) } - Some(&Token::MapEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } + Some(other) => Err(Error::UnexpectedToken(other)), None => Err(Error::EndOfTokens), } } @@ -688,47 +411,6 @@ impl<'a, I> MapVisitor for DeserializerMapVisitor<'a, I> ////////////////////////////////////////////////////////////////////////// -struct DeserializerStructVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: usize, -} - -impl<'a, I> MapVisitor for DeserializerStructVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_key_seed(&mut self, seed: K) -> Result, Error> - where K: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::StructSep) => { - self.de.tokens.next(); - self.len = self.len.saturating_sub(1); - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::StructEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), - } - } - - fn visit_value_seed(&mut self, seed: V) -> Result - where V: DeserializeSeed, - { - seed.deserialize(&mut *self.de) - } - - fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) - } -} - -////////////////////////////////////////////////////////////////////////// - struct DeserializerEnumVisitor<'a, I: 'a> where I: Iterator> { de: &'a mut Deserializer, } @@ -803,7 +485,7 @@ impl<'a, I> VariantVisitor for DeserializerEnumVisitor<'a, I> let token = self.de.tokens.next().unwrap(); if len == enum_len { - self.de.visit_variant_seq(Some(len), visitor) + self.de.visit_seq(Some(len), Token::EnumSeqSep, Token::EnumSeqEnd, visitor) } else { Err(Error::UnexpectedToken(token)) } @@ -812,7 +494,7 @@ impl<'a, I> VariantVisitor for DeserializerEnumVisitor<'a, I> let token = self.de.tokens.next().unwrap(); if len == enum_len { - self.de.visit_seq(Some(len), visitor) + self.de.visit_seq(Some(len), Token::SeqSep, Token::SeqEnd, visitor) } else { Err(Error::UnexpectedToken(token)) } @@ -834,7 +516,7 @@ impl<'a, I> VariantVisitor for DeserializerEnumVisitor<'a, I> let token = self.de.tokens.next().unwrap(); if fields.len() == enum_len { - self.de.visit_variant_map(Some(fields.len()), visitor) + self.de.visit_map(Some(fields.len()), Token::EnumMapSep, Token::EnumMapEnd, visitor) } else { Err(Error::UnexpectedToken(token)) } @@ -843,7 +525,7 @@ impl<'a, I> VariantVisitor for DeserializerEnumVisitor<'a, I> let token = self.de.tokens.next().unwrap(); if fields.len() == enum_len { - self.de.visit_map(Some(fields.len()), visitor) + self.de.visit_map(Some(fields.len()), Token::MapSep, Token::MapEnd, visitor) } else { Err(Error::UnexpectedToken(token)) } @@ -855,45 +537,3 @@ impl<'a, I> VariantVisitor for DeserializerEnumVisitor<'a, I> } } } - -////////////////////////////////////////////////////////////////////////// - -struct DeserializerVariantMapVisitor<'a, I: 'a> where I: Iterator> { - de: &'a mut Deserializer, - len: Option, -} - -impl<'a, I> MapVisitor for DeserializerVariantMapVisitor<'a, I> - where I: Iterator>, -{ - type Error = Error; - - fn visit_key_seed(&mut self, seed: K) -> Result, Error> - where K: DeserializeSeed, - { - match self.de.tokens.peek() { - Some(&Token::EnumMapSep) => { - self.de.tokens.next(); - self.len = self.len.map(|len| if len > 0 { len - 1} else { 0 }); - seed.deserialize(&mut *self.de).map(Some) - } - Some(&Token::EnumMapEnd) => Ok(None), - Some(_) => { - let token = self.de.tokens.next().unwrap(); - Err(Error::UnexpectedToken(token)) - } - None => Err(Error::EndOfTokens), - } - } - - fn visit_value_seed(&mut self, seed: V) -> Result - where V: DeserializeSeed, - { - seed.deserialize(&mut *self.de) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.len.unwrap_or(0); - (len, self.len) - } -} diff --git a/serde_test/src/lib.rs b/serde_test/src/lib.rs index 945262c0..53f966f6 100644 --- a/serde_test/src/lib.rs +++ b/serde_test/src/lib.rs @@ -1,3 +1,4 @@ +#[macro_use] extern crate serde; mod assert; diff --git a/test_suite/tests/compile-fail/enum-representation/internal-tuple-variant.rs b/test_suite/tests/compile-fail/enum-representation/internal-tuple-variant.rs new file mode 100644 index 00000000..950e8d02 --- /dev/null +++ b/test_suite/tests/compile-fail/enum-representation/internal-tuple-variant.rs @@ -0,0 +1,10 @@ +#[macro_use] +extern crate serde_derive; + +#[derive(Serialize)] //~ ERROR: custom derive attribute panicked +#[serde(tag = "type")] //~^ HELP: #[serde(tag = "...")] cannot be used with tuple variants +enum E { + Tuple(u8, u8), +} + +fn main() {} diff --git a/test_suite/tests/compile-fail/enum-representation/internally-tagged-struct.rs b/test_suite/tests/compile-fail/enum-representation/internally-tagged-struct.rs new file mode 100644 index 00000000..358dfde4 --- /dev/null +++ b/test_suite/tests/compile-fail/enum-representation/internally-tagged-struct.rs @@ -0,0 +1,8 @@ +#[macro_use] +extern crate serde_derive; + +#[derive(Serialize)] //~ ERROR: custom derive attribute panicked +#[serde(tag = "type")] //~^ HELP: #[serde(tag = "...")] can only be used on enums +struct S; + +fn main() {} diff --git a/test_suite/tests/compile-fail/enum-representation/untagged-and-internal.rs b/test_suite/tests/compile-fail/enum-representation/untagged-and-internal.rs new file mode 100644 index 00000000..7442679c --- /dev/null +++ b/test_suite/tests/compile-fail/enum-representation/untagged-and-internal.rs @@ -0,0 +1,12 @@ +#[macro_use] +extern crate serde_derive; + +#[derive(Serialize)] //~ ERROR: custom derive attribute panicked +#[serde(untagged)] +#[serde(tag = "type")] //~^^ HELP: enum cannot be both untagged and internally tagged +enum E { + A(u8), + B(String), +} + +fn main() {} diff --git a/test_suite/tests/compile-fail/enum-representation/untagged-struct.rs b/test_suite/tests/compile-fail/enum-representation/untagged-struct.rs new file mode 100644 index 00000000..611f8416 --- /dev/null +++ b/test_suite/tests/compile-fail/enum-representation/untagged-struct.rs @@ -0,0 +1,8 @@ +#[macro_use] +extern crate serde_derive; + +#[derive(Serialize)] //~ ERROR: custom derive attribute panicked +#[serde(untagged)] //~^ HELP: #[serde(untagged)] can only be used on enums +struct S; + +fn main() {} diff --git a/test_suite/tests/test_de.rs b/test_suite/tests/test_de.rs index b748bb12..6ae92129 100644 --- a/test_suite/tests/test_de.rs +++ b/test_suite/tests/test_de.rs @@ -225,7 +225,7 @@ declare_tests! { ], () => &[ Token::TupleStructStart("Anything", 0), - Token::SeqEnd, + Token::TupleStructEnd, ], } test_unit_struct { @@ -330,7 +330,7 @@ declare_tests! { ], BTreeSet::::new() => &[ Token::TupleStructStart("Anything", 0), - Token::SeqEnd, + Token::TupleStructEnd, ], } test_hashset { @@ -358,7 +358,7 @@ declare_tests! { ], HashSet::::new() => &[ Token::TupleStructStart("Anything", 0), - Token::SeqEnd, + Token::TupleStructEnd, ], hashset![FnvHasher @ 1, 2, 3] => &[ Token::SeqStart(Some(3)), @@ -408,7 +408,7 @@ declare_tests! { ], Vec::::new() => &[ Token::TupleStructStart("Anything", 0), - Token::SeqEnd, + Token::TupleStructEnd, ], } test_array { @@ -472,7 +472,7 @@ declare_tests! { ], [0; 0] => &[ Token::TupleStructStart("Anything", 0), - Token::SeqEnd, + Token::TupleStructEnd, ], } test_tuple { @@ -564,7 +564,7 @@ declare_tests! { ], BTreeMap::::new() => &[ Token::StructStart("Anything", 0), - Token::MapEnd, + Token::StructEnd, ], } test_hashmap { @@ -618,7 +618,7 @@ declare_tests! { ], HashMap::::new() => &[ Token::StructStart("Anything", 0), - Token::MapEnd, + Token::StructEnd, ], hashmap![FnvHasher @ 1 => 2, 3 => 4] => &[ Token::MapStart(Some(2)), diff --git a/test_suite/tests/test_macros.rs b/test_suite/tests/test_macros.rs index 6e5e1d82..21185412 100644 --- a/test_suite/tests/test_macros.rs +++ b/test_suite/tests/test_macros.rs @@ -1,11 +1,14 @@ extern crate serde_test; use self::serde_test::{ + Error, Token, assert_tokens, assert_ser_tokens, assert_de_tokens, + assert_de_tokens_error, }; +use std::collections::BTreeMap; use std::marker::PhantomData; // That tests that the derived Serialize implementation doesn't trigger @@ -625,3 +628,256 @@ fn test_enum_state_field() { ] ); } + +#[test] +fn test_untagged_enum() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(untagged)] + enum Untagged { + A { + a: u8, + }, + B { + b: u8, + }, + C, + D(u8), + E(String), + F(u8, u8), + } + + assert_tokens( + &Untagged::A { a: 1 }, + &[ + Token::StructStart("Untagged", 1), + + Token::StructSep, + Token::Str("a"), + Token::U8(1), + + Token::StructEnd, + ] + ); + + assert_tokens( + &Untagged::B { b: 2 }, + &[ + Token::StructStart("Untagged", 1), + + Token::StructSep, + Token::Str("b"), + Token::U8(2), + + Token::StructEnd, + ] + ); + + assert_tokens( + &Untagged::C, + &[ + Token::Unit, + ] + ); + + assert_tokens( + &Untagged::D(4), + &[ + Token::U8(4), + ] + ); + assert_tokens( + &Untagged::E("e".to_owned()), + &[ + Token::Str("e"), + ] + ); + + assert_tokens( + &Untagged::F(1, 2), + &[ + Token::TupleStart(2), + + Token::TupleSep, + Token::U8(1), + + Token::TupleSep, + Token::U8(2), + + Token::TupleEnd, + ] + ); + + assert_de_tokens_error::( + &[ + Token::Option(false), + ], + Error::Message("data did not match any variant of untagged enum Untagged".to_owned()), + ); + + assert_de_tokens_error::( + &[ + Token::TupleStart(1), + + Token::TupleSep, + Token::U8(1), + + Token::TupleEnd, + ], + Error::Message("data did not match any variant of untagged enum Untagged".to_owned()), + ); + + assert_de_tokens_error::( + &[ + Token::TupleStart(3), + + Token::TupleSep, + Token::U8(1), + + Token::TupleSep, + Token::U8(2), + + Token::TupleSep, + Token::U8(3), + + Token::TupleEnd, + ], + Error::Message("data did not match any variant of untagged enum Untagged".to_owned()), + ); +} + +#[test] +fn test_internally_tagged_enum() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct Newtype(BTreeMap); + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct Struct { + f: u8, + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(tag = "type")] + enum InternallyTagged { + A { + a: u8, + }, + B { + b: u8, + }, + C, + D(BTreeMap), + E(Newtype), + F(Struct), + } + + assert_tokens( + &InternallyTagged::A { a: 1 }, + &[ + Token::StructStart("InternallyTagged", 2), + + Token::StructSep, + Token::Str("type"), + Token::Str("A"), + + Token::StructSep, + Token::Str("a"), + Token::U8(1), + + Token::StructEnd, + ] + ); + + assert_tokens( + &InternallyTagged::B { b: 2 }, + &[ + Token::StructStart("InternallyTagged", 2), + + Token::StructSep, + Token::Str("type"), + Token::Str("B"), + + Token::StructSep, + Token::Str("b"), + Token::U8(2), + + Token::StructEnd, + ] + ); + + assert_tokens( + &InternallyTagged::C, + &[ + Token::StructStart("InternallyTagged", 1), + + Token::StructSep, + Token::Str("type"), + Token::Str("C"), + + Token::StructEnd, + ] + ); + + assert_tokens( + &InternallyTagged::D(BTreeMap::new()), + &[ + Token::MapStart(Some(1)), + + Token::MapSep, + Token::Str("type"), + Token::Str("D"), + + Token::MapEnd, + ] + ); + + assert_tokens( + &InternallyTagged::E(Newtype(BTreeMap::new())), + &[ + Token::MapStart(Some(1)), + + Token::MapSep, + Token::Str("type"), + Token::Str("E"), + + Token::MapEnd, + ] + ); + + assert_tokens( + &InternallyTagged::F(Struct { f: 6 }), + &[ + Token::StructStart("Struct", 2), + + Token::StructSep, + Token::Str("type"), + Token::Str("F"), + + Token::StructSep, + Token::Str("f"), + Token::U8(6), + + Token::StructEnd, + ] + ); + + assert_de_tokens_error::( + &[ + Token::MapStart(Some(0)), + Token::MapEnd, + ], + Error::Message("missing field `type`".to_owned()), + ); + + assert_de_tokens_error::( + &[ + Token::MapStart(Some(1)), + + Token::MapSep, + Token::Str("type"), + Token::Str("Z"), + + Token::MapEnd, + ], + Error::Message("unknown variant `Z`, expected one of `A`, `B`, `C`, `D`, `E`, `F`".to_owned()), + ); +}