From cc0d045f5c0a9a586f0eb8d4143c1a90214a0dc6 Mon Sep 17 00:00:00 2001 From: Robert O'Callahan Date: Fri, 10 Mar 2017 08:39:27 +1300 Subject: [PATCH] Add Deserialize impl for std::ops::Range Resolves #796 --- serde/src/de/impls.rs | 134 ++++++++++++++++++++++++++++++++++++ test_suite/tests/test_de.rs | 22 ++++++ 2 files changed, 156 insertions(+) diff --git a/serde/src/de/impls.rs b/serde/src/de/impls.rs index 85bf64c4..35f3cdf2 100644 --- a/serde/src/de/impls.rs +++ b/serde/src/de/impls.rs @@ -46,6 +46,9 @@ use alloc::boxed::Box; #[cfg(feature = "std")] use std::time::Duration; +#[cfg(feature = "std")] +use std; + #[cfg(feature = "unstable")] use core::nonzero::{NonZero, Zeroable}; @@ -1108,6 +1111,137 @@ impl Deserialize for Duration { } } +/////////////////////////////////////////////////////////////////////////////// + +// Similar to: +// +// #[derive(Deserialize)] +// #[serde(deny_unknown_fields)] +// struct Range { +// start: u64, +// end: u32, +// } +#[cfg(feature = "std")] +impl Deserialize for std::ops::Range { + fn deserialize(deserializer: D) -> Result + where D: Deserializer + { + enum Field { + Start, + End, + }; + + impl Deserialize for Field { + fn deserialize(deserializer: D) -> Result + where D: Deserializer + { + struct FieldVisitor; + + impl Visitor for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`start` or `end`") + } + + fn visit_str(self, value: &str) -> Result + where E: Error + { + match value { + "start" => Ok(Field::Start), + "end" => Ok(Field::End), + _ => Err(Error::unknown_field(value, FIELDS)), + } + } + + fn visit_bytes(self, value: &[u8]) -> Result + where E: Error + { + match value { + b"start" => Ok(Field::Start), + b"end" => Ok(Field::End), + _ => { + let value = String::from_utf8_lossy(value); + Err(Error::unknown_field(&value, FIELDS)) + } + } + } + } + + deserializer.deserialize_struct_field(FieldVisitor) + } + } + + struct RangeVisitor { + phantom: PhantomData, + } + + impl Visitor for RangeVisitor { + type Value = std::ops::Range; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Range") + } + + fn visit_seq(self, mut visitor: V) -> Result, V::Error> + where V: SeqVisitor + { + let start: Idx = match try!(visitor.visit()) { + Some(value) => value, + None => { + return Err(Error::invalid_length(0, &self)); + } + }; + let end: Idx = match try!(visitor.visit()) { + Some(value) => value, + None => { + return Err(Error::invalid_length(1, &self)); + } + }; + Ok(start..end) + } + + fn visit_map(self, mut visitor: V) -> Result, V::Error> + where V: MapVisitor + { + let mut start: Option = None; + let mut end: Option = None; + while let Some(key) = try!(visitor.visit_key::()) { + match key { + Field::Start => { + if start.is_some() { + return Err(::duplicate_field("start")); + } + start = Some(try!(visitor.visit_value())); + } + Field::End => { + if end.is_some() { + return Err(::duplicate_field("end")); + } + end = Some(try!(visitor.visit_value())); + } + } + } + let start = match start { + Some(start) => start, + None => return Err(::missing_field("start")), + }; + let end = match end { + Some(end) => end, + None => return Err(::missing_field("end")), + }; + Ok(start..end) + } + } + + const FIELDS: &'static [&'static str] = &["start", "end"]; + deserializer.deserialize_struct("Range", FIELDS, RangeVisitor { phantom: PhantomData }) + } +} + +/////////////////////////////////////////////////////////////////////////////// + + /////////////////////////////////////////////////////////////////////////////// #[cfg(feature = "unstable")] diff --git a/test_suite/tests/test_de.rs b/test_suite/tests/test_de.rs index f77359e2..066eb2a3 100644 --- a/test_suite/tests/test_de.rs +++ b/test_suite/tests/test_de.rs @@ -868,6 +868,28 @@ declare_tests! { Token::SeqEnd, ], } + test_range { + 1u32..2u32 => &[ + Token::StructStart("Range", 2), + Token::StructSep, + Token::Str("start"), + Token::U32(1), + + Token::StructSep, + Token::Str("end"), + Token::U32(2), + Token::StructEnd, + ], + 1u32..2u32 => &[ + Token::SeqStart(Some(2)), + Token::SeqSep, + Token::U64(1), + + Token::SeqSep, + Token::U64(2), + Token::SeqEnd, + ], + } test_net_ipv4addr { "1.2.3.4".parse::().unwrap() => &[Token::Str("1.2.3.4")], }