Rewrite enum deserialization to not require allocations

This commit is contained in:
Erick Tryzelaar 2015-03-15 22:04:17 -07:00
parent b40d8f7bac
commit 78137ee3a4
7 changed files with 273 additions and 147 deletions

View File

@ -245,9 +245,11 @@ mod deserializer {
use serde::de;
#[derive(Debug)]
enum State {
AnimalState(Animal),
IsizeState(isize),
StrState(&'static str),
StringState(String),
UnitState,
}
@ -273,29 +275,51 @@ mod deserializer {
where V: de::Visitor,
{
match self.stack.pop() {
Some(State::AnimalState(Animal::Dog)) => {
self.stack.push(State::UnitState);
visitor.visit_enum("Animal", "Dog", DogVisitor {
de: self,
})
}
Some(State::AnimalState(Animal::Frog(x0, x1))) => {
self.stack.push(State::IsizeState(x1));
self.stack.push(State::StringState(x0));
visitor.visit_enum("Animal", "Frog", FrogVisitor {
de: self,
state: 0,
})
}
Some(State::IsizeState(value)) => {
visitor.visit_isize(value)
}
Some(State::StringState(value)) => {
visitor.visit_string(value)
}
Some(State::StrState(value)) => {
visitor.visit_str(value)
}
Some(State::UnitState) => {
visitor.visit_unit()
}
Some(_) => {
Err(Error::SyntaxError)
}
None => {
Err(Error::EndOfStreamError)
}
}
}
#[inline]
fn visit_enum<V>(&mut self, _name: &str, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumVisitor,
{
match self.stack.pop() {
Some(State::AnimalState(Animal::Dog)) => {
self.stack.push(State::UnitState);
self.stack.push(State::StrState("Dog"));
visitor.visit(DogVisitor {
de: self,
})
}
Some(State::AnimalState(Animal::Frog(x0, x1))) => {
self.stack.push(State::IsizeState(x1));
self.stack.push(State::StringState(x0));
self.stack.push(State::StrState("Frog"));
visitor.visit(FrogVisitor {
de: self,
state: 0,
})
}
Some(_) => {
Err(Error::SyntaxError)
}
None => {
Err(Error::EndOfStreamError)
}
@ -307,12 +331,30 @@ mod deserializer {
de: &'a mut AnimalDeserializer,
}
impl<'a> de::EnumVisitor for DogVisitor<'a> {
impl<'a> de::VariantVisitor for DogVisitor<'a> {
type Error = Error;
fn visit_variant<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize
{
de::Deserialize::deserialize(self.de)
}
fn visit_unit(&mut self) -> Result<(), Error> {
de::Deserialize::deserialize(self.de)
}
fn visit_seq<V>(&mut self, _visitor: V) -> Result<V::Value, Error>
where V: de::EnumSeqVisitor
{
Err(de::Error::syntax_error())
}
fn visit_map<V>(&mut self, _visitor: V) -> Result<V::Value, Error>
where V: de::EnumMapVisitor
{
Err(de::Error::syntax_error())
}
}
struct FrogVisitor<'a> {
@ -320,14 +362,30 @@ mod deserializer {
state: usize,
}
impl<'a> de::EnumVisitor for FrogVisitor<'a> {
impl<'a> de::VariantVisitor for FrogVisitor<'a> {
type Error = Error;
fn visit_variant<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize
{
de::Deserialize::deserialize(self.de)
}
fn visit_unit(&mut self) -> Result<(), Error> {
Err(de::Error::syntax_error())
}
fn visit_seq<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumSeqVisitor,
{
visitor.visit(self)
}
fn visit_map<V>(&mut self, _visitor: V) -> Result<V::Value, Error>
where V: de::EnumMapVisitor
{
Err(de::Error::syntax_error())
}
}
impl<'a> de::SeqVisitor for FrogVisitor<'a> {

View File

@ -386,17 +386,32 @@ fn deserialize_item_enum(
let type_name = builder.expr().str(type_ident);
let variant_visitor = deserialize_field_visitor(
cx,
builder,
enum_def.variants.iter()
.map(|variant| builder.expr().str(variant.node.name))
.collect()
);
// Match arms to extract a variant from a string
let variant_arms: Vec<_> = enum_def.variants.iter()
.map(|variant| {
deserialize_variant(
.enumerate()
.map(|(i, variant)| {
let variant_name = builder.expr().path()
.id("__Field").id(format!("__field{}", i))
.build();
let expr = deserialize_variant(
cx,
builder,
type_ident,
impl_generics,
ty.clone(),
variant,
)
);
quote_arm!(cx, $variant_name => { $expr })
})
.collect();
@ -406,37 +421,23 @@ fn deserialize_item_enum(
);
quote_expr!(cx, {
$variant_visitor
$visitor_item
impl $impl_generics ::serde::de::Visitor for $visitor_ty $where_clause {
impl $impl_generics ::serde::de::EnumVisitor for $visitor_ty $where_clause {
type Value = $ty;
fn visit_enum<__V>(&mut self,
name: &str,
variant: &str,
visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::EnumVisitor,
fn visit<__V>(&mut self, mut visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::VariantVisitor,
{
if name == $type_name {
self.visit_variant(variant, visitor)
} else {
Err(::serde::de::Error::syntax_error())
}
}
fn visit_variant<__V>(&mut self,
name: &str,
mut visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::EnumVisitor
{
match name {
match try!(visitor.visit_variant()) {
$variant_arms
_ => Err(::serde::de::Error::syntax_error()),
}
}
}
deserializer.visit_enum($visitor_expr)
deserializer.visit_enum($type_name, $visitor_expr)
})
}
@ -447,21 +448,18 @@ fn deserialize_variant(
generics: &ast::Generics,
ty: P<ast::Ty>,
variant: &ast::Variant,
) -> ast::Arm {
) -> P<ast::Expr> {
let variant_ident = variant.node.name;
let variant_name = builder.expr().str(variant_ident);
match variant.node.kind {
ast::TupleVariantKind(ref args) if args.is_empty() => {
quote_arm!(cx,
$variant_name => {
try!(visitor.visit_unit());
Ok($type_ident::$variant_ident)
}
)
quote_expr!(cx, {
try!(visitor.visit_unit());
Ok($type_ident::$variant_ident)
})
}
ast::TupleVariantKind(ref args) => {
let expr = deserialize_tuple_variant(
deserialize_tuple_variant(
cx,
builder,
type_ident,
@ -469,12 +467,10 @@ fn deserialize_variant(
generics,
ty,
args.len(),
);
quote_arm!(cx, $variant_name => { $expr })
)
}
ast::StructVariantKind(ref struct_def) => {
let expr = deserialize_struct_variant(
deserialize_struct_variant(
cx,
builder,
type_ident,
@ -482,9 +478,7 @@ fn deserialize_variant(
generics,
ty,
struct_def,
);
quote_arm!(cx, $variant_name => { $expr })
)
}
}
}
@ -574,10 +568,10 @@ fn deserialize_struct_variant(
fn deserialize_field_visitor(
cx: &ExtCtxt,
builder: &aster::AstBuilder,
struct_def: &StructDef,
field_exprs: Vec<P<ast::Expr>>,
) -> Vec<P<ast::Item>> {
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. struct_def.fields.len())
let field_idents: Vec<ast::Ident> = (0 .. field_exprs.len())
.map(|i| builder.id(format!("__field{}", i)))
.collect();
@ -585,20 +579,17 @@ fn deserialize_field_visitor(
.attr().allow(&["non_camel_case_types"])
.enum_("__Field")
.with_variants(
field_names.iter().map(|field| {
builder.variant(field).tuple().build()
field_idents.iter().map(|field_ident| {
builder.variant(field_ident).tuple().build()
})
)
.build();
// Get aliases
let aliases = field::struct_field_strs(cx, builder, struct_def);
// Match arms to extract a field from a string
let field_arms: Vec<ast::Arm> = aliases.iter()
.zip(field_names.iter())
.map(|(alias, field_name)| {
quote_arm!(cx, $alias => { Ok(__Field::$field_name) })
let field_arms: Vec<_> = field_idents.iter()
.zip(field_exprs.into_iter())
.map(|(field_ident, field_expr)| {
quote_arm!(cx, $field_expr => { Ok(__Field::$field_ident) })
})
.collect();
@ -642,7 +633,7 @@ fn deserialize_struct_visitor(
let field_visitor = deserialize_field_visitor(
cx,
builder,
struct_def,
field::struct_field_strs(cx, builder, struct_def),
);
let visit_map_expr = deserialize_map(

View File

@ -43,10 +43,10 @@ pub trait Deserializer {
/// deserializers that provide a custom enumeration serialization to
/// properly deserialize the type.
#[inline]
fn visit_enum<V>(&mut self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor,
fn visit_enum<V>(&mut self, _enum: &str, _visitor: V) -> Result<V::Value, Self::Error>
where V: EnumVisitor,
{
self.visit(visitor)
Err(Error::syntax_error())
}
}
@ -204,23 +204,6 @@ pub trait Visitor {
{
self.visit_map(visitor)
}
#[inline]
fn visit_enum<V>(&mut self,
_name: &str,
_variant: &str,
_visitor: V) -> Result<Self::Value, V::Error>
where V: EnumVisitor,
{
Err(Error::syntax_error())
}
#[inline]
fn visit_variant<V>(&mut self, _name: &str, _visitor: V) -> Result<Self::Value, V::Error>
where V: EnumVisitor,
{
Err(Error::syntax_error())
}
}
///////////////////////////////////////////////////////////////////////////////
@ -338,22 +321,53 @@ impl<'a, V_> MapVisitor for &'a mut V_ where V_: MapVisitor {
///////////////////////////////////////////////////////////////////////////////
pub trait EnumVisitor {
type Value;
fn visit<V>(&mut self, visitor: V) -> Result<Self::Value, V::Error>
where V: VariantVisitor;
}
///////////////////////////////////////////////////////////////////////////////
pub trait VariantVisitor {
type Error: Error;
fn visit_unit(&mut self) -> Result<(), Self::Error> {
Err(Error::syntax_error())
}
fn visit_variant<V>(&mut self) -> Result<V, Self::Error>
where V: Deserialize;
fn visit_unit(&mut self) -> Result<(), Self::Error>;
fn visit_seq<V>(&mut self, _visitor: V) -> Result<V::Value, Self::Error>
where V: EnumSeqVisitor,
{
Err(Error::syntax_error())
}
where V: EnumSeqVisitor;
fn visit_map<V>(&mut self, _visitor: V) -> Result<V::Value, Self::Error>
where V: EnumMapVisitor,
where V: EnumMapVisitor;
}
impl<'a, T> VariantVisitor for &'a mut T where T: VariantVisitor {
type Error = T::Error;
fn visit_variant<V>(&mut self) -> Result<V, T::Error>
where V: Deserialize
{
Err(Error::syntax_error())
(**self).visit_variant()
}
{
fn visit_unit(&mut self) -> Result<(), T::Error> {
(**self).visit_unit()
}
fn visit_seq<V>(&mut self, visitor: V) -> Result<V::Value, T::Error>
where V: EnumSeqVisitor
{
(**self).visit_seq(visitor)
}
fn visit_map<V>(&mut self, visitor: V) -> Result<V::Value, T::Error>
where V: EnumMapVisitor
{
(**self).visit_map(visitor)
}
}

View File

@ -409,8 +409,8 @@ impl<Iter> de::Deserializer for Deserializer<Iter>
}
#[inline]
fn visit_enum<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::Visitor,
fn visit_enum<V>(&mut self, _name: &str, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumVisitor,
{
self.parse_whitespace();
@ -418,14 +418,9 @@ impl<Iter> de::Deserializer for Deserializer<Iter>
self.bump();
self.parse_whitespace();
try!(self.parse_string());
try!(self.parse_object_colon());
let variant = str::from_utf8(&self.buf).unwrap().to_string();
let value = try!(visitor.visit_variant(&variant, EnumVisitor {
de: self,
}));
let value = {
try!(visitor.visit(&mut *self))
};
self.parse_whitespace();
@ -433,7 +428,7 @@ impl<Iter> de::Deserializer for Deserializer<Iter>
self.bump();
Ok(value)
} else {
return Err(self.error(ErrorCode::ExpectedSomeValue));
Err(self.error(ErrorCode::ExpectedSomeValue))
}
} else {
Err(self.error(ErrorCode::ExpectedSomeValue))
@ -541,7 +536,6 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter>
}
if self.de.eof() {
println!("here3");
return Err(self.de.error(ErrorCode::EOFWhileParsingValue));
}
@ -599,42 +593,58 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter>
}
}
struct EnumVisitor<'a, Iter: 'a> {
de: &'a mut Deserializer<Iter>,
}
impl<'a, Iter> de::EnumVisitor for EnumVisitor<'a, Iter>
impl<Iter> de::VariantVisitor for Deserializer<Iter>
where Iter: Iterator<Item=u8>,
{
type Error = Error;
fn visit_variant<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize
{
de::Deserialize::deserialize(self)
}
/*
fn visit_value<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize
{
de::Deserialize::deserialize(self)
}
*/
fn visit_unit(&mut self) -> Result<(), Error> {
de::Deserialize::deserialize(self.de)
try!(self.parse_object_colon());
de::Deserialize::deserialize(self)
}
fn visit_seq<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumSeqVisitor,
where V: de::EnumSeqVisitor
{
self.de.parse_whitespace();
try!(self.parse_object_colon());
if self.de.ch_is(b'[') {
self.de.bump();
visitor.visit(SeqVisitor::new(self.de))
self.parse_whitespace();
if self.ch_is(b'[') {
self.bump();
visitor.visit(SeqVisitor::new(self))
} else {
Err(self.de.error(ErrorCode::ExpectedSomeValue))
Err(self.error(ErrorCode::ExpectedSomeValue))
}
}
fn visit_map<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumMapVisitor,
where V: de::EnumMapVisitor
{
self.de.parse_whitespace();
try!(self.parse_object_colon());
if self.de.ch_is(b'{') {
self.de.bump();
visitor.visit(MapVisitor::new(self.de))
self.parse_whitespace();
if self.ch_is(b'{') {
self.bump();
visitor.visit(MapVisitor::new(self))
} else {
Err(self.de.error(ErrorCode::ExpectedSomeValue))
Err(self.error(ErrorCode::ExpectedSomeValue))
}
}
}

View File

@ -420,8 +420,8 @@ impl de::Deserializer for Deserializer {
}
#[inline]
fn visit_enum<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::Visitor,
fn visit_enum<V>(&mut self, _name: &str, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumVisitor,
{
let value = match self.value.take() {
Some(Value::Object(value)) => value,
@ -433,16 +433,20 @@ impl de::Deserializer for Deserializer {
let value = match iter.next() {
Some((variant, Value::Array(fields))) => {
self.value = Some(Value::String(variant));
let len = fields.len();
try!(visitor.visit_variant(&variant, SeqDeserializer {
try!(visitor.visit(SeqDeserializer {
de: self,
iter: fields.into_iter(),
len: len,
}))
}
Some((variant, Value::Object(fields))) => {
self.value = Some(Value::String(variant));
let len = fields.len();
try!(visitor.visit_variant(&variant, MapDeserializer {
try!(visitor.visit(MapDeserializer {
de: self,
iter: fields.into_iter(),
value: None,
@ -495,9 +499,15 @@ impl<'a> de::SeqVisitor for SeqDeserializer<'a> {
}
}
impl<'a> de::EnumVisitor for SeqDeserializer<'a> {
impl<'a> de::VariantVisitor for SeqDeserializer<'a> {
type Error = Error;
fn visit_variant<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize,
{
de::Deserialize::deserialize(self.de)
}
fn visit_unit(&mut self) -> Result<(), Error> {
if self.len == 0 {
Ok(())
@ -511,6 +521,12 @@ impl<'a> de::EnumVisitor for SeqDeserializer<'a> {
{
visitor.visit(self)
}
fn visit_map<V>(&mut self, _visitor: V) -> Result<V::Value, Error>
where V: de::EnumMapVisitor
{
Err(de::Error::syntax_error())
}
}
struct MapDeserializer<'a> {
@ -583,9 +599,25 @@ impl<'a> de::MapVisitor for MapDeserializer<'a> {
}
}
impl<'a> de::EnumVisitor for MapDeserializer<'a> {
impl<'a> de::VariantVisitor for MapDeserializer<'a> {
type Error = Error;
fn visit_variant<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize,
{
de::Deserialize::deserialize(self.de)
}
fn visit_unit(&mut self) -> Result<(), Error> {
Err(de::Error::syntax_error())
}
fn visit_seq<V>(&mut self, _visitor: V) -> Result<V::Value, Error>
where V: de::EnumSeqVisitor
{
Err(de::Error::syntax_error())
}
fn visit_map<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumMapVisitor,
{

View File

@ -44,7 +44,7 @@ enum Token<'a> {
MapSep(bool),
MapEnd,
EnumStart(&'a str, &'a str),
EnumStart(&'a str),
EnumEnd,
}
@ -132,11 +132,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> {
first: true,
})
}
Some(Token::EnumStart(name, variant)) => {
visitor.visit_enum(name, variant, TokenDeserializerEnumVisitor {
de: self,
})
}
Some(_) => Err(Error::SyntaxError),
None => Err(Error::EndOfStreamError),
}
@ -144,7 +139,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> {
/// Hook into `Option` deserializing so we can treat `Unit` as a
/// `None`, or a regular value as `Some(value)`.
#[inline]
fn visit_option<V>(&mut self, mut visitor: V) -> Result<V::Value, Error>
where V: Visitor,
{
@ -165,6 +159,24 @@ impl<'a> Deserializer for TokenDeserializer<'a> {
None => Err(Error::EndOfStreamError),
}
}
fn visit_enum<V>(&mut self, name: &str, mut visitor: V) -> Result<V::Value, Error>
where V: de::EnumVisitor,
{
match self.tokens.next() {
Some(Token::EnumStart(n)) => {
if name == n {
visitor.visit(TokenDeserializerVariantVisitor {
de: self,
})
} else {
Err(Error::SyntaxError)
}
}
Some(_) => Err(Error::SyntaxError),
None => Err(Error::EndOfStreamError),
}
}
}
//////////////////////////////////////////////////////////////////////////
@ -263,13 +275,19 @@ impl<'a, 'b> de::MapVisitor for TokenDeserializerMapVisitor<'a, 'b> {
//////////////////////////////////////////////////////////////////////////
struct TokenDeserializerEnumVisitor<'a, 'b: 'a> {
struct TokenDeserializerVariantVisitor<'a, 'b: 'a> {
de: &'a mut TokenDeserializer<'b>,
}
impl<'a, 'b> de::EnumVisitor for TokenDeserializerEnumVisitor<'a, 'b> {
impl<'a, 'b> de::VariantVisitor for TokenDeserializerVariantVisitor<'a, 'b> {
type Error = Error;
fn visit_kind<V>(&mut self) -> Result<V, Error>
where V: de::Deserialize,
{
de::Deserialize::deserialize(self.de)
}
fn visit_unit(&mut self) -> Result<(), Error> {
let value = try!(Deserialize::deserialize(self.de));
@ -611,16 +629,18 @@ declare_tests! {
Token::MapEnd,
],
}
test_enum {
test_enum_unit {
Enum::Unit => vec![
Token::EnumStart("Enum", "Unit"),
Token::EnumStart("Enum"),
Token::Str("Unit"),
Token::Unit,
Token::EnumEnd,
],
}
test_enum_seq {
Enum::Seq(1, 2, 3) => vec![
Token::EnumStart("Enum", "Seq"),
Token::EnumStart("Enum"),
Token::Str("Seq"),
Token::SeqStart(3),
Token::SeqSep(true),
Token::I32(1),
@ -636,7 +656,8 @@ declare_tests! {
}
test_enum_map {
Enum::Map { a: 1, b: 2, c: 3 } => vec![
Token::EnumStart("Enum", "Map"),
Token::EnumStart("Enum"),
Token::Str("Map"),
Token::MapStart(3),
Token::MapSep(true),
Token::Str("a"),

View File

@ -907,7 +907,7 @@ fn test_parse_option() {
#[test]
fn test_parse_enum() {
test_parse_err::<Animal>(&[
("{}", Error::SyntaxError(ErrorCode::EOFWhileParsingString, 1, 3)),
("{}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 1, 2)),
("{\"unknown\":[]}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 0, 0)),
("{\"Dog\":{}}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 0, 0)),
("{\"Frog\":{}}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 1, 9)),