From d17846eff1f4cda0e80a6526be285f1b9252c52d Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Fri, 20 Mar 2015 08:32:33 -0700 Subject: [PATCH] Add deserializer type hinting hooks Formats like xml have trouble knowing if they should deserialize tags into a sequence from the stream they are deserializing from. This PR adds hooks so the deserializee can inform the deserializer to provide them a sequence if possible. Closes #38. --- benches/bench_struct.rs | 93 ++++++++++++++++++------------ serde_macros/src/de.rs | 43 +------------- src/de.rs | 81 +++++++++++++++++++-------- tests/test_de.rs | 121 +++++++++++++++++++++++++++------------- 4 files changed, 200 insertions(+), 138 deletions(-) diff --git a/benches/bench_struct.rs b/benches/bench_struct.rs index c625a555..c77a782d 100644 --- a/benches/bench_struct.rs +++ b/benches/bench_struct.rs @@ -30,7 +30,7 @@ pub struct Outer { ////////////////////////////////////////////////////////////////////////////// -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Error { EndOfStream, SyntaxError, @@ -366,30 +366,6 @@ mod deserializer { where V: de::Visitor, { match self.stack.pop() { - Some(State::OuterState(Outer { inner })) => { - self.stack.push(State::VecState(inner)); - self.stack.push(State::StrState("inner")); - - visitor.visit_named_map("Outer", OuterMapVisitor { - de: self, - state: 0, - }) - } - Some(State::InnerState(Inner { a: (), b, c })) => { - self.stack.push(State::MapState(c)); - self.stack.push(State::StrState("c")); - - self.stack.push(State::UsizeState(b)); - self.stack.push(State::StrState("b")); - - self.stack.push(State::NullState); - self.stack.push(State::StrState("a")); - - visitor.visit_named_map("Inner", InnerMapVisitor { - de: self, - state: 0, - }) - } Some(State::VecState(value)) => { visitor.visit_seq(OuterSeqVisitor { de: self, @@ -423,9 +399,52 @@ mod deserializer { Some(State::OptionState(true)) => { visitor.visit_some(self) } + Some(token) => Err(Error::SyntaxError), None => Err(Error::EndOfStream), } } + + fn visit_named_map(&mut self, name: &str, mut visitor: V) -> Result + where V: de::Visitor, + { + match self.stack.pop() { + Some(State::OuterState(Outer { inner })) => { + if name != "Outer" { + return Err(Error::SyntaxError); + } + + self.stack.push(State::VecState(inner)); + self.stack.push(State::StrState("inner")); + + visitor.visit_map(OuterMapVisitor { + de: self, + state: 0, + }) + } + Some(State::InnerState(Inner { a: (), b, c })) => { + if name != "Inner" { + return Err(Error::SyntaxError); + } + + self.stack.push(State::MapState(c)); + self.stack.push(State::StrState("c")); + + self.stack.push(State::UsizeState(b)); + self.stack.push(State::StrState("b")); + + self.stack.push(State::NullState); + self.stack.push(State::StrState("a")); + + visitor.visit_map(InnerMapVisitor { + de: self, + state: 0, + }) + } + _ => { + Err(Error::SyntaxError) + } + } + } } struct OuterMapVisitor<'a> { @@ -605,9 +624,9 @@ fn bench_decoder_0_0(b: &mut Bencher) { }; let mut d = decoder::OuterDecoder::new(outer.clone()); - let value: Outer = Decodable::decode(&mut d).unwrap(); + let value: Result = Decodable::decode(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } @@ -627,9 +646,9 @@ fn bench_decoder_1_0(b: &mut Bencher) { }; let mut d = decoder::OuterDecoder::new(outer.clone()); - let value: Outer = Decodable::decode(&mut d).unwrap(); + let value: Result = Decodable::decode(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } @@ -654,9 +673,9 @@ fn bench_decoder_1_5(b: &mut Bencher) { }; let mut d = decoder::OuterDecoder::new(outer.clone()); - let value: Outer = Decodable::decode(&mut d).unwrap(); + let value: Result = Decodable::decode(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } @@ -668,9 +687,9 @@ fn bench_deserializer_0_0(b: &mut Bencher) { }; let mut d = deserializer::OuterDeserializer::new(outer.clone()); - let value: Outer = Deserialize::deserialize(&mut d).unwrap(); + let value: Result = Deserialize::deserialize(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } @@ -690,9 +709,9 @@ fn bench_deserializer_1_0(b: &mut Bencher) { }; let mut d = deserializer::OuterDeserializer::new(outer.clone()); - let value: Outer = Deserialize::deserialize(&mut d).unwrap(); + let value: Result = Deserialize::deserialize(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } @@ -717,8 +736,8 @@ fn bench_deserializer_1_5(b: &mut Bencher) { }; let mut d = deserializer::OuterDeserializer::new(outer.clone()); - let value: Outer = Deserialize::deserialize(&mut d).unwrap(); + let value: Result = Deserialize::deserialize(&mut d); - assert_eq!(value, outer); + assert_eq!(value, Ok(outer)); }) } diff --git a/serde_macros/src/de.rs b/serde_macros/src/de.rs index ecbfc95d..b48e61ca 100644 --- a/serde_macros/src/de.rs +++ b/serde_macros/src/de.rs @@ -205,18 +205,6 @@ fn deserialize_unit_struct( Ok($type_ident) } - #[inline] - fn visit_named_unit< - E: ::serde::de::Error, - >(&mut self, name: &str) -> Result<$type_ident, E> { - if name == $type_name { - self.visit_unit() - } else { - Err(::serde::de::Error::syntax_error()) - } - } - - #[inline] fn visit_seq(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> where V: ::serde::de::SeqVisitor, @@ -226,7 +214,7 @@ fn deserialize_unit_struct( } } - deserializer.visit(__Visitor) + deserializer.visit_named_unit($type_name, __Visitor) }) } @@ -265,21 +253,9 @@ fn deserialize_tuple_struct( { $visit_seq_expr } - - fn visit_named_seq<__V>(&mut self, - name: &str, - visitor: __V) -> Result<$ty, __V::Error> - where __V: ::serde::de::SeqVisitor, - { - if name == $type_name { - self.visit_seq(visitor) - } else { - Err(::serde::de::Error::syntax_error()) - } - } } - deserializer.visit($visitor_expr) + deserializer.visit_named_seq($type_name, $visitor_expr) }) } @@ -355,22 +331,9 @@ fn deserialize_struct( { $visit_map_expr } - - #[inline] - fn visit_named_map<__V>(&mut self, - name: &str, - visitor: __V) -> Result<$ty, __V::Error> - where __V: ::serde::de::MapVisitor, - { - if name == $type_name { - self.visit_map(visitor) - } else { - Err(::serde::de::Error::syntax_error()) - } - } } - deserializer.visit($visitor_expr) + deserializer.visit_named_map($type_name, $visitor_expr) }) } diff --git a/src/de.rs b/src/de.rs index 31d468c8..08abb83c 100644 --- a/src/de.rs +++ b/src/de.rs @@ -18,22 +18,24 @@ pub trait Error { /////////////////////////////////////////////////////////////////////////////// pub trait Deserialize { + /// Deserialize this value given this `Deserializer`. fn deserialize(deserializer: &mut D) -> Result where D: Deserializer; } /////////////////////////////////////////////////////////////////////////////// +/// `Deserializer` is an abstract trait that can deserialize values into a `Visitor`. pub trait Deserializer { type Error: Error; + /// The `visit` method walks a visitor through a value as it is being deserialized. fn visit(&mut self, visitor: V) -> Result where V: Visitor; - /// The `visit_option` method allows a `Deserialize` type to inform the - /// `Deserializer` that it's expecting an optional value. This allows - /// deserializers that encode an optional value as a nullable value to - /// convert the null value into a `None`, and a regular value as + /// The `visit_option` method allows a `Deserialize` type to inform the `Deserializer` that + /// it's expecting an optional value. This allows deserializers that encode an optional value + /// as a nullable value to convert the null value into a `None`, and a regular value as /// `Some(value)`. #[inline] fn visit_option(&mut self, visitor: V) -> Result @@ -42,10 +44,59 @@ pub trait Deserializer { self.visit(visitor) } - /// The `visit_enum` method allows a `Deserialize` type to inform the - /// `Deserializer` that it's expecting an enum value. This allows - /// deserializers that provide a custom enumeration serialization to - /// properly deserialize the type. + /// The `visit_seq` method allows a `Deserialize` type to inform the `Deserializer` that it's + /// expecting a sequence of values. This allows deserializers to parse sequences that aren't + /// tagged as sequences. + #[inline] + fn visit_seq(&mut self, visitor: V) -> Result + where V: Visitor, + { + self.visit(visitor) + } + + /// The `visit_map` method allows a `Deserialize` type to inform the `Deserializer` that it's + /// expecting a map of values. This allows deserializers to parse sequences that aren't tagged + /// as maps. + #[inline] + fn visit_map(&mut self, visitor: V) -> Result + where V: Visitor, + { + self.visit(visitor) + } + + /// The `visit_named_unit` method allows a `Deserialize` type to inform the `Deserializer` that + /// it's expecting a named unit. This allows deserializers to a named unit that aren't tagged + /// as a named unit. + #[inline] + fn visit_named_unit(&mut self, _name: &str, visitor: V) -> Result + where V: Visitor, + { + self.visit(visitor) + } + + /// The `visit_named_seq` method allows a `Deserialize` type to inform the `Deserializer` that + /// it's expecting a named sequence of values. This allows deserializers to parse sequences + /// that aren't tagged as sequences. + #[inline] + fn visit_named_seq(&mut self, _name: &str, visitor: V) -> Result + where V: Visitor, + { + self.visit_seq(visitor) + } + + /// The `visit_named_map` method allows a `Deserialize` type to inform the `Deserializer` that + /// it's expecting a map of values. This allows deserializers to parse sequences that aren't + /// tagged as maps. + #[inline] + fn visit_named_map(&mut self, _name: &str, visitor: V) -> Result + where V: Visitor, + { + self.visit_map(visitor) + } + + /// The `visit_enum` method allows a `Deserialize` type to inform the `Deserializer` that it's + /// expecting an enum value. This allows deserializers that provide a custom enumeration + /// serialization to properly deserialize the type. #[inline] fn visit_enum(&mut self, _enum: &str, _visitor: V) -> Result where V: EnumVisitor, @@ -191,25 +242,11 @@ pub trait Visitor { Err(Error::syntax_error()) } - #[inline] - fn visit_named_seq(&mut self, _name: &str, visitor: V) -> Result - where V: SeqVisitor, - { - self.visit_seq(visitor) - } - fn visit_map(&mut self, _visitor: V) -> Result where V: MapVisitor, { Err(Error::syntax_error()) } - - #[inline] - fn visit_named_map(&mut self, _name: &str, visitor: V) -> Result - where V: MapVisitor, - { - self.visit_map(visitor) - } } /////////////////////////////////////////////////////////////////////////////// diff --git a/tests/test_de.rs b/tests/test_de.rs index 53e3a50e..fe584d50 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -11,7 +11,7 @@ use std::vec; use serde::de::{self, Deserialize, Deserializer, Visitor}; #[derive(Debug)] -enum Token<'a> { +enum Token { Bool(bool), Isize(isize), I8(i8), @@ -26,34 +26,33 @@ enum Token<'a> { F32(f32), F64(f64), Char(char), - Str(&'a str), + Str(&'static str), String(String), Option(bool), + Name(&'static str), + Unit, - NamedUnit(&'a str), SeqStart(usize), - NamedSeqStart(&'a str, usize), SeqSep(bool), SeqEnd, MapStart(usize), - NamedMapStart(&'a str, usize), MapSep(bool), MapEnd, - EnumStart(&'a str), + EnumStart(&'static str), EnumEnd, } -struct TokenDeserializer<'a> { - tokens: iter::Peekable>>, +struct TokenDeserializer { + tokens: iter::Peekable>, } -impl<'a> TokenDeserializer<'a> { - fn new(tokens: Vec>) -> TokenDeserializer<'a> { +impl<'a> TokenDeserializer { + fn new(tokens: Vec) -> TokenDeserializer { TokenDeserializer { tokens: tokens.into_iter().peekable(), } @@ -65,6 +64,7 @@ enum Error { SyntaxError, EndOfStreamError, MissingFieldError(&'static str), + InvalidName(&'static str), } impl de::Error for Error { @@ -77,7 +77,7 @@ impl de::Error for Error { } } -impl<'a> Deserializer for TokenDeserializer<'a> { +impl Deserializer for TokenDeserializer { type Error = Error; fn visit(&mut self, mut visitor: V) -> Result @@ -103,7 +103,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> { Some(Token::Option(false)) => visitor.visit_none(), Some(Token::Option(true)) => visitor.visit_some(self), Some(Token::Unit) => visitor.visit_unit(), - Some(Token::NamedUnit(name)) => visitor.visit_named_unit(name), Some(Token::SeqStart(len)) => { visitor.visit_seq(TokenDeserializerSeqVisitor { de: self, @@ -111,13 +110,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> { first: true, }) } - Some(Token::NamedSeqStart(name, len)) => { - visitor.visit_named_seq(name, TokenDeserializerSeqVisitor { - de: self, - len: len, - first: true, - }) - } Some(Token::MapStart(len)) => { visitor.visit_map(TokenDeserializerMapVisitor { de: self, @@ -125,13 +117,7 @@ impl<'a> Deserializer for TokenDeserializer<'a> { first: true, }) } - Some(Token::NamedMapStart(name, len)) => { - visitor.visit_named_map(name, TokenDeserializerMapVisitor { - de: self, - len: len, - first: true, - }) - } + Some(Token::Name(_)) => self.visit(visitor), Some(_) => Err(Error::SyntaxError), None => Err(Error::EndOfStreamError), } @@ -177,17 +163,68 @@ impl<'a> Deserializer for TokenDeserializer<'a> { None => Err(Error::EndOfStreamError), } } + + fn visit_named_unit(&mut self, name: &str, visitor: V) -> Result + where V: de::Visitor, + { + match self.tokens.peek() { + Some(&Token::Name(n)) => { + if name == n { + self.tokens.next(); + self.visit_seq(visitor) + } else { + Err(Error::InvalidName(n)) + } + } + Some(_) => self.visit(visitor), + None => Err(Error::EndOfStreamError), + } + } + + fn visit_named_seq(&mut self, name: &str, visitor: V) -> Result + where V: de::Visitor, + { + match self.tokens.peek() { + Some(&Token::Name(n)) => { + if name == n { + self.tokens.next(); + self.visit_seq(visitor) + } else { + Err(Error::InvalidName(n)) + } + } + Some(_) => self.visit_seq(visitor), + None => Err(Error::EndOfStreamError), + } + } + + fn visit_named_map(&mut self, name: &str, visitor: V) -> Result + where V: de::Visitor, + { + match self.tokens.peek() { + Some(&Token::Name(n)) => { + if name == n { + self.tokens.next(); + self.visit_map(visitor) + } else { + Err(Error::InvalidName(n)) + } + } + Some(_) => self.visit_map(visitor), + None => Err(Error::EndOfStreamError), + } + } } ////////////////////////////////////////////////////////////////////////// -struct TokenDeserializerSeqVisitor<'a, 'b: 'a> { - de: &'a mut TokenDeserializer<'b>, +struct TokenDeserializerSeqVisitor<'a> { + de: &'a mut TokenDeserializer, len: usize, first: bool, } -impl<'a, 'b> de::SeqVisitor for TokenDeserializerSeqVisitor<'a, 'b> { +impl<'a> de::SeqVisitor for TokenDeserializerSeqVisitor<'a> { type Error = Error; fn visit(&mut self) -> Result, Error> @@ -226,13 +263,13 @@ impl<'a, 'b> de::SeqVisitor for TokenDeserializerSeqVisitor<'a, 'b> { ////////////////////////////////////////////////////////////////////////// -struct TokenDeserializerMapVisitor<'a, 'b: 'a> { - de: &'a mut TokenDeserializer<'b>, +struct TokenDeserializerMapVisitor<'a> { + de: &'a mut TokenDeserializer, len: usize, first: bool, } -impl<'a, 'b> de::MapVisitor for TokenDeserializerMapVisitor<'a, 'b> { +impl<'a> de::MapVisitor for TokenDeserializerMapVisitor<'a> { type Error = Error; fn visit_key(&mut self) -> Result, Error> @@ -275,11 +312,11 @@ impl<'a, 'b> de::MapVisitor for TokenDeserializerMapVisitor<'a, 'b> { ////////////////////////////////////////////////////////////////////////// -struct TokenDeserializerVariantVisitor<'a, 'b: 'a> { - de: &'a mut TokenDeserializer<'b>, +struct TokenDeserializerVariantVisitor<'a> { + de: &'a mut TokenDeserializer, } -impl<'a, 'b> de::VariantVisitor for TokenDeserializerVariantVisitor<'a, 'b> { +impl<'a> de::VariantVisitor for TokenDeserializerVariantVisitor<'a> { type Error = Error; fn visit_variant(&mut self) -> Result @@ -418,13 +455,17 @@ declare_tests! { Token::SeqEnd, ], () => vec![ - Token::NamedSeqStart("Anything", 0), + Token::Name("Anything"), + Token::SeqStart(0), Token::SeqEnd, ], } test_named_unit { NamedUnit => vec![Token::Unit], - NamedUnit => vec![Token::NamedUnit("NamedUnit")], + NamedUnit => vec![ + Token::Name("NamedUnit"), + Token::Unit, + ], NamedUnit => vec![ Token::SeqStart(0), Token::SeqEnd, @@ -444,7 +485,8 @@ declare_tests! { Token::SeqEnd, ], NamedSeq(1, 2, 3) => vec![ - Token::NamedSeqStart("NamedSeq", 3), + Token::Name("NamedSeq"), + Token::SeqStart(3), Token::SeqSep(true), Token::I32(1), @@ -561,7 +603,8 @@ declare_tests! { Token::MapEnd, ], NamedMap { a: 1, b: 2, c: 3 } => vec![ - Token::NamedMapStart("NamedMap", 3), + Token::Name("NamedMap"), + Token::MapStart(3), Token::MapSep(true), Token::Str("a"), Token::I32(1),