diff --git a/serde/src/de/mod.rs b/serde/src/de/mod.rs index 826662df..7ff92104 100644 --- a/serde/src/de/mod.rs +++ b/serde/src/de/mod.rs @@ -1132,6 +1132,20 @@ pub trait Deserializer<'de>: Sized { fn is_human_readable(&self) -> bool { true } + + // Not public API. + #[doc(hidden)] + fn private_deserialize_internally_tagged_enum( + self, + visitor: V, + tag: &'static str, + ) -> Result + where + V: Visitor<'de>, + { + let _ = tag; + self.deserialize_any(visitor) + } } //////////////////////////////////////////////////////////////////////////////// diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index b05ecb96..6a2ccff1 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -2715,6 +2715,22 @@ where byte_buf option unit unit_struct seq tuple tuple_struct identifier ignored_any } + + fn private_deserialize_internally_tagged_enum( + self, + visitor: V, + tag: &'static str, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(FlatInternallyTaggedAccess { + iter: self.0.iter_mut(), + pending: Pending::None, + tag: tag, + _marker: PhantomData, + }) + } } #[cfg(any(feature = "std", feature = "alloc"))] @@ -2786,3 +2802,80 @@ where } } } + +#[cfg(any(feature = "std", feature = "alloc"))] +pub struct FlatInternallyTaggedAccess<'a, 'de: 'a, E> { + iter: slice::IterMut<'a, Option<(Content<'de>, Content<'de>)>>, + pending: Pending<'a, 'de>, + tag: &'static str, + _marker: PhantomData, +} + +#[cfg(any(feature = "std", feature = "alloc"))] +enum Pending<'a, 'de: 'a> { + Content(Content<'de>), + ContentRef(&'a Content<'de>), + None, +} + +#[cfg(any(feature = "std", feature = "alloc"))] +impl<'a, 'de, E> MapAccess<'de> for FlatInternallyTaggedAccess<'a, 'de, E> +where + E: Error, +{ + type Error = E; + + fn next_key_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + while let Some(item) = self.iter.next() { + let is_tag = match *item { + Some((ref key, _)) => key.as_str().map_or(false, |key| key == self.tag), + None => continue, + }; + + let ret = if is_tag { + let (key, content) = item.take().unwrap(); + self.pending = Pending::Content(content); + + // Could use `ContentDeserializer::new(key)` here but we prefer + // to avoid instantiating `seed.deserialize` twice with + // different Deserializer type parameters. The visitor that + // decides which variant we are looking at does not benefit from + // having ownership of this string. + seed.deserialize(ContentRefDeserializer::new(&key)) + } else { + // Do not take(), instead borrow this entry. The internally + // tagged enum does its own buffering so we can't tell whether + // this entry is going to be consumed. Borrowing here leaves the + // entry available for later flattened fields. + let (ref key, ref content) = *item.as_ref().unwrap(); + self.pending = Pending::ContentRef(content); + seed.deserialize(ContentRefDeserializer::new(key)) + }; + + return ret.map(Some); + } + Ok(None) + } + + fn next_value_seed(&mut self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + match mem::replace(&mut self.pending, Pending::None) { + Pending::Content(value) => { + // Could use `ContentDeserializer::new(value)` here but we + // prefer to avoid instantiating `seed.deserialize` twice with + // different Deserializer type parameters. Flatten and internal + // tagging are both relatively slow at runtime anyway so the + // improvement in compile time is more important here than + // potentially saving some string copies. + seed.deserialize(ContentRefDeserializer::new(&value)) + } + Pending::ContentRef(value) => seed.deserialize(ContentRefDeserializer::new(value)), + Pending::None => panic!("value is missing"), + } + } +} diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index d8970090..c8b1e7f3 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1193,9 +1193,10 @@ fn deserialize_internally_tagged_enum( #variants_stmt - let __tagged = try!(_serde::Deserializer::deserialize_any( + let __tagged = try!(_serde::Deserializer::private_deserialize_internally_tagged_enum( __deserializer, - _serde::private::de::TaggedContentVisitor::<__Field>::new(#tag))); + _serde::private::de::TaggedContentVisitor::<__Field>::new(#tag), + #tag)); match __tagged.tag { #(#variant_arms)* diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 5e38ad34..815a47b6 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -1793,3 +1793,49 @@ fn test_flatten_enum_newtype() { ], ); } + +#[test] +fn test_flatten_internally_tagged() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct S { + #[serde(flatten)] + x: X, + #[serde(flatten)] + y: Y, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(tag = "typeX")] + enum X { + A { a: i32 }, + B { b: i32 }, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(tag = "typeY")] + enum Y { + C { c: i32 }, + D { d: i32 }, + } + + let s = S { + x: X::B { b: 1 }, + y: Y::D { d: 2 }, + }; + + assert_tokens( + &s, + &[ + Token::Map { len: None }, + Token::Str("typeX"), + Token::Str("B"), + Token::Str("b"), + Token::I32(1), + Token::Str("typeY"), + Token::Str("D"), + Token::Str("d"), + Token::I32(2), + Token::MapEnd, + ], + ); +}