From 1552eb72dcf6cfaffc0be263522fa9bdca408d61 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Wed, 11 Feb 2015 08:56:27 -0800 Subject: [PATCH] Add #[derive_deserialize] support for enums --- serde2/serde2_macros/src/lib.rs | 640 ++++++++++++++++---------------- serde2/src/de.rs | 19 +- 2 files changed, 344 insertions(+), 315 deletions(-) diff --git a/serde2/serde2_macros/src/lib.rs b/serde2/serde2_macros/src/lib.rs index e76ceee0..831ea512 100644 --- a/serde2/serde2_macros/src/lib.rs +++ b/serde2/serde2_macros/src/lib.rs @@ -22,7 +22,7 @@ use syntax::ext::deriving::generic::{ Named, StaticFields, StaticStruct, - //StaticEnum, + StaticEnum, Struct, Substructure, TraitDef, @@ -302,20 +302,19 @@ fn deserialize_substructure(cx: &ExtCtxt, span: Span, substr: &Substructure) -> cx, span, substr.type_ident, + substr.type_ident, + cx.path(span, vec![substr.type_ident]), fields, state) } - /* StaticEnum(_, ref fields) => { deserialize_enum( cx, span, substr.type_ident, &fields, - deserializer, - token) + state) } - */ _ => cx.bug("expected StaticEnum or StaticStruct in derive(Deserialize)") } } @@ -324,42 +323,141 @@ fn deserialize_struct( cx: &ExtCtxt, span: Span, type_ident: Ident, + struct_ident: Ident, + struct_path: ast::Path, fields: &StaticFields, state: P, ) -> P { match *fields { Unnamed(ref fields) => { - deserialize_struct_unnamed_fields( - cx, - span, - type_ident, - &fields[], - state) + if fields.is_empty() { + deserialize_struct_empty_fields( + cx, + span, + type_ident, + struct_ident, + struct_path, + state) + } else { + deserialize_struct_unnamed_fields( + cx, + span, + type_ident, + struct_ident, + struct_path, + &fields[], + state) + } } Named(ref fields) => { deserialize_struct_named_fields( cx, span, type_ident, + struct_ident, + struct_path, &fields[], state) } } } +fn deserialize_struct_empty_fields( + cx: &ExtCtxt, + span: Span, + type_ident: Ident, + struct_ident: Ident, + struct_path: ast::Path, + state: P, +) -> P { + let struct_name = cx.expr_str(span, token::get_ident(struct_ident)); + + let result = cx.expr_path(struct_path); + + quote_expr!(cx, { + struct __Visitor; + + impl ::serde2::de::Visitor for __Visitor { + type Value = $type_ident; + + #[inline] + fn visit_unit< + E: ::serde2::de::Error, + >(&mut self) -> Result<$type_ident, E> { + Ok($result) + } + + #[inline] + fn visit_named_unit< + E: ::serde2::de::Error, + >(&mut self, name: &str) -> Result<$type_ident, E> { + if name == $struct_name { + self.visit_unit() + } else { + Err(::serde2::de::Error::syntax_error()) + } + } + } + + $state.visit(&mut __Visitor) + }) +} + fn deserialize_struct_unnamed_fields( cx: &ExtCtxt, span: Span, type_ident: Ident, + struct_ident: Ident, + struct_path: ast::Path, fields: &[Span], state: P, ) -> P { - let type_name = cx.expr_str(span, token::get_ident(type_ident)); + let struct_name = cx.expr_str(span, token::get_ident(struct_ident)); let field_names: Vec = (0 .. fields.len()) .map(|i| token::str_to_ident(&format!("__field{}", i))) .collect(); + let visit_seq_expr = declare_visit_seq( + cx, + span, + struct_path, + &field_names[], + ); + + quote_expr!(cx, { + struct __Visitor; + + impl ::serde2::de::Visitor for __Visitor { + type Value = $type_ident; + + fn visit_seq< + __V: ::serde2::de::SeqVisitor, + >(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> { + $visit_seq_expr + } + + fn visit_named_seq< + __V: ::serde2::de::SeqVisitor, + >(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> { + if name == $struct_name { + self.visit_seq(visitor) + } else { + Err(::serde2::de::Error::syntax_error()) + } + } + } + + $state.visit(&mut __Visitor) + }) +} + +fn declare_visit_seq( + cx: &ExtCtxt, + span: Span, + struct_path: ast::Path, + field_names: &[Ident], +) -> P { let let_values: Vec> = field_names.iter() .map(|name| { quote_stmt!(cx, @@ -373,32 +471,72 @@ fn deserialize_struct_unnamed_fields( }) .collect(); - let result = cx.expr_call_ident( + let result = cx.expr_call( span, - type_ident, + cx.expr_path(struct_path), field_names.iter().map(|name| cx.expr_ident(span, *name)).collect()); quote_expr!(cx, { + $let_values + + try!(visitor.end()); + + Ok($result) + }) +} + +fn deserialize_struct_named_fields( + cx: &ExtCtxt, + span: Span, + type_ident: Ident, + struct_ident: Ident, + struct_path: ast::Path, + fields: &[(Ident, Span)], + state: P, +) -> P { + let struct_name = cx.expr_str(span, token::get_ident(struct_ident)); + + // Create the field names for the fields. + let field_names: Vec = (0 .. fields.len()) + .map(|i| token::str_to_ident(&format!("__field{}", i))) + .collect(); + + let field_deserializer = declare_map_field_deserializer( + cx, + span, + &field_names[], + fields, + ); + + let visit_map_expr = declare_visit_map( + cx, + span, + struct_path, + &field_names[], + fields, + ); + + quote_expr!(cx, { + $field_deserializer + struct __Visitor; impl ::serde2::de::Visitor for __Visitor { type Value = $type_ident; - fn visit_seq< - __V: ::serde2::de::SeqVisitor, + #[inline] + fn visit_map< + __V: ::serde2::de::MapVisitor, >(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> { - $let_values - - try!(visitor.end()); - - Ok($result) + $visit_map_expr } - fn visit_named_seq< - __V: ::serde2::de::SeqVisitor, + #[inline] + fn visit_named_map< + __V: ::serde2::de::MapVisitor, >(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> { - if name == $type_name { - self.visit_seq(visitor) + if name == $struct_name { + self.visit_map(visitor) } else { Err(::serde2::de::Error::syntax_error()) } @@ -409,20 +547,12 @@ fn deserialize_struct_unnamed_fields( }) } -fn deserialize_struct_named_fields( +fn declare_map_field_deserializer( cx: &ExtCtxt, span: Span, - type_ident: Ident, + field_names: &[ast::Ident], fields: &[(Ident, Span)], - state: P, -) -> P { - let type_name = cx.expr_str(span, token::get_ident(type_ident)); - - // Create the field names for the fields. - let field_names: Vec = (0 .. fields.len()) - .map(|i| token::str_to_ident(&format!("__field{}", i))) - .collect(); - +) -> Vec> { // Create the field names for the fields. let field_variants: Vec> = field_names.iter() .map(|field| { @@ -453,6 +583,52 @@ fn deserialize_struct_named_fields( }) .collect(); + vec![ + quote_item!(cx, + #[allow(non_camel_case_types)] + $field_enum + ).unwrap(), + + quote_item!(cx, + struct __FieldVisitor; + ).unwrap(), + + quote_item!(cx, + impl ::serde2::de::Visitor for __FieldVisitor { + type Value = __Field; + + fn visit_str< + E: ::serde2::de::Error, + >(&mut self, value: &str) -> Result<__Field, E> { + match value { + $field_arms + _ => Err(::serde2::de::Error::syntax_error()), + } + } + } + ).unwrap(), + + quote_item!(cx, + impl ::serde2::de::Deserialize for __Field { + #[inline] + fn deserialize< + __S: ::serde2::de::Deserializer, + >(state: &mut __S) -> Result<__Field, __S::Error> { + state.visit(&mut __FieldVisitor) + } + } + ).unwrap(), + ] +} + +fn declare_visit_map( + cx: &ExtCtxt, + span: Span, + struct_path: ast::Path, + field_names: &[Ident], + fields: &[(Ident, Span)], +) -> P { + // Declare each field. let let_values: Vec> = field_names.iter() .map(|field| { @@ -484,9 +660,9 @@ fn deserialize_struct_named_fields( }) .collect(); - let result = cx.expr_struct_ident( + let result = cx.expr_struct( span, - type_ident, + struct_path, fields.iter() .zip(field_names.iter()) .map(|(&(name, span), field)| { @@ -496,230 +672,15 @@ fn deserialize_struct_named_fields( ); quote_expr!(cx, { - #[allow(non_camel_case_types)] - $field_enum + $let_values - struct __FieldVisitor; - - impl ::serde2::de::Visitor for __FieldVisitor { - type Value = __Field; - - fn visit_str< - E: ::serde2::de::Error, - >(&mut self, value: &str) -> Result<__Field, E> { - match value { - $field_arms - _ => Err(::serde2::de::Error::syntax_error()), - } + while let Some(key) = try!(visitor.visit_key()) { + match key { + $value_arms } } - impl ::serde2::de::Deserialize for __Field { - #[inline] - fn deserialize< - __S: ::serde2::de::Deserializer, - >(state: &mut __S) -> Result<__Field, __S::Error> { - state.visit(&mut __FieldVisitor) - } - } - - struct __Visitor; - - impl ::serde2::de::Visitor for __Visitor { - type Value = $type_ident; - - fn visit_map< - __V: ::serde2::de::MapVisitor, - >(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> { - $let_values - - while let Some(key) = try!(visitor.visit_key()) { - match key { - $value_arms - } - } - - $extract_values - Ok($result) - } - - fn visit_named_map< - __V: ::serde2::de::MapVisitor, - >(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> { - if name == $type_name { - self.visit_map(visitor) - } else { - Err(::serde2::de::Error::syntax_error()) - } - } - } - - $state.visit(&mut __Visitor) - }) -} - -/* -fn deserialize_struct( - cx: &ExtCtxt, - span: Span, - type_ident: Ident, - fields: &StaticFields, - deserializer: P, - token: P -) -> P { - /* - let struct_block = deserialize_struct_from_struct( - cx, - span, - type_ident, - fields, - deserializer - ); - */ - - let map_block = deserialize_struct_from_map( - cx, - span, - type_ident, - fields, - deserializer - ); - - quote_expr!( - cx, - match $token { - ::serde2::de::StructStart(_, _) => $struct_block, - ::serde2::de::MapStart(_) => $map_block, - token => { - let expected_tokens = [ - ::serde2::de::StructStartKind, - ::serde2::de::MapStartKind, - ]; - Err($deserializer.syntax_error(token, expected_tokens)) - } - } - ) -} - -/* -fn deserialize_struct_from_struct( - cx: &ExtCtxt, - span: Span, - type_ident: Ident, - fields: &StaticFields, - deserializer: P -) -> P { - let expect_struct_field = cx.ident_of("expect_struct_field"); - - let call = deserialize_static_fields( - cx, - span, - type_ident, - fields, - |cx, span, name| { - let name = cx.expr_str(span, name); - quote_expr!( - cx, - try!($deserializer.expect_struct_field($name)) - ) - } - ); - - quote_expr!(cx, { - let result = $call; - try!($deserializer.expect_struct_end()); - Ok(result) - }) -} -*/ - -fn deserialize_struct_from_map( - cx: &ExtCtxt, - span: Span, - type_ident: Ident, - fields: &StaticFields, - deserializer: P -) -> P { - let fields = match *fields { - Unnamed(_) => panic!(), - Named(ref fields) => &fields[], - }; - - // Declare each field. - let let_fields: Vec> = fields.iter() - .map(|&(name, span)| { - quote_stmt!(cx, let mut $name = None) - }) - .collect(); - - // Declare key arms. - let key_arms: Vec = fields.iter() - .map(|&(name, span)| { - let s = cx.expr_str(span, token::get_ident(name)); - quote_arm!(cx, - $s => { - $name = Some( - try!(::serde2::de::Deserialize::deserialize($deserializer)) - ); - continue; - }) - }) - .collect(); - - let extract_fields: Vec> = fields.iter() - .map(|&(name, span)| { - let name_str = cx.expr_str(span, token::get_ident(name)); - quote_stmt!(cx, - let $name = match $name { - Some($name) => $name, - None => try!($deserializer.missing_field($name_str)), - }; - ) - }) - .collect(); - - let result = cx.expr_struct_ident( - span, - type_ident, - fields.iter() - .map(|&(name, span)| { - cx.field_imm(span, name, cx.expr_ident(span, name)) - }) - .collect() - ); - - quote_expr!(cx, { - $let_fields - - loop { - let token = match try!($deserializer.expect_token()) { - ::serde2::de::End => { break; } - token => token, - }; - - { - let key = match token { - ::serde2::de::Str(s) => s, - ::serde2::de::String(ref s) => &s, - token => { - let expected_tokens = [ - ::serde2::de::StrKind, - ::serde2::de::StringKind, - ]; - return Err($deserializer.syntax_error(token, expected_tokens)); - } - }; - - match key { - $key_arms - _ => { } - } - } - - try!($deserializer.ignore_field(token)) - } - - $extract_fields + $extract_values Ok($result) }) } @@ -729,89 +690,144 @@ fn deserialize_enum( span: Span, type_ident: Ident, fields: &[(Ident, Span, StaticFields)], - deserializer: P, - token: P + state: P, ) -> P { let type_name = cx.expr_str(span, token::get_ident(type_ident)); - let variants = fields.iter() - .map(|&(name, span, _)| { - cx.expr_str(span, token::get_ident(name)) - }) - .collect(); - - let variants = cx.expr_vec(span, variants); - - let arms: Vec = fields.iter() - .enumerate() - .map(|(i, &(name, span, ref parts))| { - let call = deserialize_static_fields( + // Match arms to extract a variant from a string + let variant_arms: Vec = fields.iter() + .map(|&(name, span, ref fields)| { + let value = deserialize_enum_variant( cx, span, + type_ident, name, - parts, - |cx, span, _| { - quote_expr!(cx, try!($deserializer.expect_enum_elt())) - } + fields, + cx.expr_ident(span, cx.ident_of("visitor")), ); - quote_arm!(cx, $i => $call,) + let s = cx.expr_str(span, token::get_ident(name)); + quote_arm!(cx, $s => $value,) }) .collect(); quote_expr!(cx, { - let i = try!($deserializer.expect_enum_start($token, $type_name, $variants)); + struct __Visitor; - let result = match i { - $arms - _ => { unreachable!() } - }; + impl ::serde2::de::Visitor for __Visitor { + type Value = $type_ident; - try!($deserializer.expect_enum_end()); + fn visit_enum< + __V: ::serde2::de::EnumVisitor, + >(&mut self, name: &str, variant: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> { + if name == $type_name { + self.visit_variant(variant, visitor) + } else { + Err(::serde2::de::Error::syntax_error()) + } + } - Ok(result) + fn visit_variant< + __V: ::serde2::de::EnumVisitor, + >(&mut self, name: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> { + match name { + $variant_arms + _ => Err(::serde2::de::Error::syntax_error()), + } + } + } + + $state.visit(&mut __Visitor) }) } -/// Create a deserializer for a single enum variant/struct: -/// - `outer_pat_ident` is the name of this enum variant/struct -/// - `getarg` should retrieve the `u32`-th field with name `&str`. -fn deserialize_static_fields( +fn deserialize_enum_variant( cx: &ExtCtxt, span: Span, - outer_pat_ident: Ident, + type_ident: Ident, + variant_ident: Ident, fields: &StaticFields, - getarg: |&ExtCtxt, Span, token::InternedString| -> P -) -> P { + state: P, +) -> P { + let variant_path = cx.path(span, vec![type_ident, variant_ident]); + match *fields { Unnamed(ref fields) => { if fields.is_empty() { - cx.expr_ident(span, outer_pat_ident) - } else { - let fields = fields.iter().enumerate().map(|(i, &span)| { - getarg( - cx, - span, - token::intern_and_get_ident(&format!("_field{}", i)) - ) - }).collect(); + let result = cx.expr_path(variant_path); - cx.expr_call_ident(span, outer_pat_ident, fields) + quote_expr!(cx, { + try!($state.visit_unit()); + Ok($result) + }) + } else { + // Create the field names for the fields. + let field_names: Vec = (0 .. fields.len()) + .map(|i| token::str_to_ident(&format!("__field{}", i))) + .collect(); + + let visit_seq_expr = declare_visit_seq( + cx, + span, + variant_path, + &field_names[], + ); + + quote_expr!(cx, { + struct __Visitor; + + impl ::serde2::de::EnumSeqVisitor for __Visitor { + type Value = $type_ident; + + fn visit< + V: ::serde2::de::SeqVisitor, + >(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> { + $visit_seq_expr + } + } + + $state.visit_seq(&mut __Visitor) + }) } } Named(ref fields) => { - // use the field's span to get nicer error messages. - let fields = fields.iter().map(|&(name, span)| { - let arg = getarg( - cx, - span, - token::get_ident(name) - ); - cx.field_imm(span, name, arg) - }).collect(); + // Create the field names for the fields. + let field_names: Vec = (0 .. fields.len()) + .map(|i| token::str_to_ident(&format!("__field{}", i))) + .collect(); - cx.expr_struct_ident(span, outer_pat_ident, fields) + let field_deserializer = declare_map_field_deserializer( + cx, + span, + &field_names[], + fields, + ); + + let visit_map_expr = declare_visit_map( + cx, + span, + variant_path, + &field_names[], + fields, + ); + + quote_expr!(cx, { + $field_deserializer + + struct __Visitor; + + impl ::serde2::de::EnumMapVisitor for __Visitor { + type Value = $type_ident; + + fn visit< + V: ::serde2::de::MapVisitor, + >(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> { + $visit_map_expr + } + } + + $state.visit_map(&mut __Visitor) + }) } } } -*/ diff --git a/serde2/src/de.rs b/serde2/src/de.rs index 462553c0..62aa048c 100644 --- a/serde2/src/de.rs +++ b/serde2/src/de.rs @@ -201,6 +201,13 @@ pub trait Visitor { >(&mut self, _name: &str, _variant: &str, _visitor: V) -> Result { Err(Error::syntax_error()) } + + #[inline] + fn visit_variant< + V: EnumVisitor, + >(&mut self, _name: &str, _visitor: V) -> Result { + Err(Error::syntax_error()) + } } pub trait SeqVisitor { @@ -1177,11 +1184,17 @@ mod tests { fn visit_enum< V: super::EnumVisitor, >(&mut self, name: &str, variant: &str, mut visitor: V) -> Result { - if name != "Enum" { - return Err(super::Error::syntax_error()); + if name == "Enum" { + self.visit_variant(variant, visitor) + } else { + Err(super::Error::syntax_error()); } + } - match variant { + fn visit_variant< + V: super::EnumVisitor, + >(&mut self, name: &str, mut visitor: V) -> Result { + match name { "Unit" => { try!(visitor.visit_unit()); Ok(Enum::Unit)