diff --git a/serde/src/de/value.rs b/serde/src/de/value.rs index 3a522969..d0a185da 100644 --- a/serde/src/de/value.rs +++ b/serde/src/de/value.rs @@ -1287,10 +1287,40 @@ where visitor.visit_map(self.map) } + fn deserialize_enum( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_enum(self) + } + forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map struct enum identifier ignored_any + tuple_struct map struct identifier ignored_any + } +} + +impl<'de, A> de::EnumAccess<'de> for MapAccessDeserializer +where + A: de::MapAccess<'de>, +{ + type Error = A::Error; + type Variant = private::MapAsEnum; + + fn variant_seed(mut self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> + where + T: de::DeserializeSeed<'de>, + { + match self.map.next_key_seed(seed)? { + Some(key) => Ok((key, private::map_as_enum(self.map))), + None => Err(de::Error::invalid_type(de::Unexpected::Map, &"enum")), + } } } @@ -1299,7 +1329,7 @@ where mod private { use lib::*; - use de::{self, Unexpected}; + use de::{self, DeserializeSeed, Deserializer, MapAccess, Unexpected, VariantAccess, Visitor}; #[derive(Clone, Debug)] pub struct UnitOnly { @@ -1360,6 +1390,92 @@ mod private { } } + #[derive(Clone, Debug)] + pub struct MapAsEnum { + map: A, + } + + pub fn map_as_enum(map: A) -> MapAsEnum { + MapAsEnum { map: map } + } + + impl<'de, A> VariantAccess<'de> for MapAsEnum + where + A: MapAccess<'de>, + { + type Error = A::Error; + + fn unit_variant(mut self) -> Result<(), Self::Error> { + self.map.next_value() + } + + fn newtype_variant_seed(mut self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + self.map.next_value_seed(seed) + } + + fn tuple_variant(mut self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.map.next_value_seed(SeedTupleVariant { + len: len, + visitor: visitor, + }) + } + + fn struct_variant( + mut self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.map + .next_value_seed(SeedStructVariant { visitor: visitor }) + } + } + + struct SeedTupleVariant { + len: usize, + visitor: V, + } + + impl<'de, V> DeserializeSeed<'de> for SeedTupleVariant + where + V: Visitor<'de>, + { + type Value = V::Value; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_tuple(self.len, self.visitor) + } + } + + struct SeedStructVariant { + visitor: V, + } + + impl<'de, V> DeserializeSeed<'de> for SeedStructVariant + where + V: Visitor<'de>, + { + type Value = V::Value; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_map(self.visitor) + } + } + /// Avoid having to restate the generic types on `MapDeserializer`. The /// `Iterator::Item` contains enough information to figure out K and V. pub trait Pair { diff --git a/test_suite/tests/test_value.rs b/test_suite/tests/test_value.rs index 8d8413cb..0bb169df 100644 --- a/test_suite/tests/test_value.rs +++ b/test_suite/tests/test_value.rs @@ -1,5 +1,8 @@ -use serde::de::{value, IntoDeserializer}; -use serde::Deserialize; +use serde::de::value::{self, MapAccessDeserializer}; +use serde::de::{IntoDeserializer, MapAccess, Visitor}; +use serde::{Deserialize, Deserializer}; +use serde_test::{assert_de_tokens, Token}; +use std::fmt; #[test] fn test_u32_to_enum() { @@ -32,3 +35,60 @@ fn test_integer128() { // i128 to i128 assert_eq!(1i128, i128::deserialize(de_i128).unwrap()); } + +#[test] +fn test_map_access_to_enum() { + #[derive(PartialEq, Debug)] + struct Potential(PotentialKind); + + #[derive(PartialEq, Debug, Deserialize)] + enum PotentialKind { + Airebo(Airebo), + } + + #[derive(PartialEq, Debug, Deserialize)] + struct Airebo { + lj_sigma: f64, + } + + impl<'de> Deserialize<'de> for Potential { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct PotentialVisitor; + + impl<'de> Visitor<'de> for PotentialVisitor { + type Value = Potential; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a map") + } + + fn visit_map(self, map: A) -> Result + where + A: MapAccess<'de>, + { + Deserialize::deserialize(MapAccessDeserializer::new(map)).map(Potential) + } + } + + deserializer.deserialize_any(PotentialVisitor) + } + } + + let expected = Potential(PotentialKind::Airebo(Airebo { lj_sigma: 14.0 })); + + assert_de_tokens( + &expected, + &[ + Token::Map { len: Some(1) }, + Token::Str("Airebo"), + Token::Map { len: Some(1) }, + Token::Str("lj_sigma"), + Token::F64(14.0), + Token::MapEnd, + Token::MapEnd, + ], + ); +}