Add #[derive_deserialize] support for enums

This commit is contained in:
Erick Tryzelaar 2015-02-11 08:56:27 -08:00
parent 3fd42e616c
commit 1552eb72dc
2 changed files with 344 additions and 315 deletions

View File

@ -22,7 +22,7 @@ use syntax::ext::deriving::generic::{
Named,
StaticFields,
StaticStruct,
//StaticEnum,
StaticEnum,
Struct,
Substructure,
TraitDef,
@ -302,20 +302,19 @@ fn deserialize_substructure(cx: &ExtCtxt, span: Span, substr: &Substructure) ->
cx,
span,
substr.type_ident,
substr.type_ident,
cx.path(span, vec![substr.type_ident]),
fields,
state)
}
/*
StaticEnum(_, ref fields) => {
deserialize_enum(
cx,
span,
substr.type_ident,
&fields,
deserializer,
token)
state)
}
*/
_ => cx.bug("expected StaticEnum or StaticStruct in derive(Deserialize)")
}
}
@ -324,42 +323,141 @@ fn deserialize_struct(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &StaticFields,
state: P<ast::Expr>,
) -> P<ast::Expr> {
match *fields {
Unnamed(ref fields) => {
deserialize_struct_unnamed_fields(
cx,
span,
type_ident,
&fields[],
state)
if fields.is_empty() {
deserialize_struct_empty_fields(
cx,
span,
type_ident,
struct_ident,
struct_path,
state)
} else {
deserialize_struct_unnamed_fields(
cx,
span,
type_ident,
struct_ident,
struct_path,
&fields[],
state)
}
}
Named(ref fields) => {
deserialize_struct_named_fields(
cx,
span,
type_ident,
struct_ident,
struct_path,
&fields[],
state)
}
}
}
fn deserialize_struct_empty_fields(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
state: P<ast::Expr>,
) -> P<ast::Expr> {
let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
let result = cx.expr_path(struct_path);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
#[inline]
fn visit_unit<
E: ::serde2::de::Error,
>(&mut self) -> Result<$type_ident, E> {
Ok($result)
}
#[inline]
fn visit_named_unit<
E: ::serde2::de::Error,
>(&mut self, name: &str) -> Result<$type_ident, E> {
if name == $struct_name {
self.visit_unit()
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
fn deserialize_struct_unnamed_fields(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &[Span],
state: P<ast::Expr>,
) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident));
let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
let visit_seq_expr = declare_visit_seq(
cx,
span,
struct_path,
&field_names[],
);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
fn visit_seq<
__V: ::serde2::de::SeqVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$visit_seq_expr
}
fn visit_named_seq<
__V: ::serde2::de::SeqVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $struct_name {
self.visit_seq(visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
fn declare_visit_seq(
cx: &ExtCtxt,
span: Span,
struct_path: ast::Path,
field_names: &[Ident],
) -> P<ast::Expr> {
let let_values: Vec<P<ast::Stmt>> = field_names.iter()
.map(|name| {
quote_stmt!(cx,
@ -373,32 +471,72 @@ fn deserialize_struct_unnamed_fields(
})
.collect();
let result = cx.expr_call_ident(
let result = cx.expr_call(
span,
type_ident,
cx.expr_path(struct_path),
field_names.iter().map(|name| cx.expr_ident(span, *name)).collect());
quote_expr!(cx, {
$let_values
try!(visitor.end());
Ok($result)
})
}
fn deserialize_struct_named_fields(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &[(Ident, Span)],
state: P<ast::Expr>,
) -> P<ast::Expr> {
let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
let field_deserializer = declare_map_field_deserializer(
cx,
span,
&field_names[],
fields,
);
let visit_map_expr = declare_visit_map(
cx,
span,
struct_path,
&field_names[],
fields,
);
quote_expr!(cx, {
$field_deserializer
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
fn visit_seq<
__V: ::serde2::de::SeqVisitor,
#[inline]
fn visit_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$let_values
try!(visitor.end());
Ok($result)
$visit_map_expr
}
fn visit_named_seq<
__V: ::serde2::de::SeqVisitor,
#[inline]
fn visit_named_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name {
self.visit_seq(visitor)
if name == $struct_name {
self.visit_map(visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
@ -409,20 +547,12 @@ fn deserialize_struct_unnamed_fields(
})
}
fn deserialize_struct_named_fields(
fn declare_map_field_deserializer(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
field_names: &[ast::Ident],
fields: &[(Ident, Span)],
state: P<ast::Expr>,
) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident));
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
) -> Vec<P<ast::Item>> {
// Create the field names for the fields.
let field_variants: Vec<P<ast::Variant>> = field_names.iter()
.map(|field| {
@ -453,6 +583,52 @@ fn deserialize_struct_named_fields(
})
.collect();
vec![
quote_item!(cx,
#[allow(non_camel_case_types)]
$field_enum
).unwrap(),
quote_item!(cx,
struct __FieldVisitor;
).unwrap(),
quote_item!(cx,
impl ::serde2::de::Visitor for __FieldVisitor {
type Value = __Field;
fn visit_str<
E: ::serde2::de::Error,
>(&mut self, value: &str) -> Result<__Field, E> {
match value {
$field_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
}
}
).unwrap(),
quote_item!(cx,
impl ::serde2::de::Deserialize for __Field {
#[inline]
fn deserialize<
__S: ::serde2::de::Deserializer,
>(state: &mut __S) -> Result<__Field, __S::Error> {
state.visit(&mut __FieldVisitor)
}
}
).unwrap(),
]
}
fn declare_visit_map(
cx: &ExtCtxt,
span: Span,
struct_path: ast::Path,
field_names: &[Ident],
fields: &[(Ident, Span)],
) -> P<ast::Expr> {
// Declare each field.
let let_values: Vec<P<ast::Stmt>> = field_names.iter()
.map(|field| {
@ -484,9 +660,9 @@ fn deserialize_struct_named_fields(
})
.collect();
let result = cx.expr_struct_ident(
let result = cx.expr_struct(
span,
type_ident,
struct_path,
fields.iter()
.zip(field_names.iter())
.map(|(&(name, span), field)| {
@ -496,230 +672,15 @@ fn deserialize_struct_named_fields(
);
quote_expr!(cx, {
#[allow(non_camel_case_types)]
$field_enum
$let_values
struct __FieldVisitor;
impl ::serde2::de::Visitor for __FieldVisitor {
type Value = __Field;
fn visit_str<
E: ::serde2::de::Error,
>(&mut self, value: &str) -> Result<__Field, E> {
match value {
$field_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
while let Some(key) = try!(visitor.visit_key()) {
match key {
$value_arms
}
}
impl ::serde2::de::Deserialize for __Field {
#[inline]
fn deserialize<
__S: ::serde2::de::Deserializer,
>(state: &mut __S) -> Result<__Field, __S::Error> {
state.visit(&mut __FieldVisitor)
}
}
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
fn visit_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$let_values
while let Some(key) = try!(visitor.visit_key()) {
match key {
$value_arms
}
}
$extract_values
Ok($result)
}
fn visit_named_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name {
self.visit_map(visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
/*
fn deserialize_struct(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>,
token: P<ast::Expr>
) -> P<ast::Expr> {
/*
let struct_block = deserialize_struct_from_struct(
cx,
span,
type_ident,
fields,
deserializer
);
*/
let map_block = deserialize_struct_from_map(
cx,
span,
type_ident,
fields,
deserializer
);
quote_expr!(
cx,
match $token {
::serde2::de::StructStart(_, _) => $struct_block,
::serde2::de::MapStart(_) => $map_block,
token => {
let expected_tokens = [
::serde2::de::StructStartKind,
::serde2::de::MapStartKind,
];
Err($deserializer.syntax_error(token, expected_tokens))
}
}
)
}
/*
fn deserialize_struct_from_struct(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>
) -> P<ast::Expr> {
let expect_struct_field = cx.ident_of("expect_struct_field");
let call = deserialize_static_fields(
cx,
span,
type_ident,
fields,
|cx, span, name| {
let name = cx.expr_str(span, name);
quote_expr!(
cx,
try!($deserializer.expect_struct_field($name))
)
}
);
quote_expr!(cx, {
let result = $call;
try!($deserializer.expect_struct_end());
Ok(result)
})
}
*/
fn deserialize_struct_from_map(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>
) -> P<ast::Expr> {
let fields = match *fields {
Unnamed(_) => panic!(),
Named(ref fields) => &fields[],
};
// Declare each field.
let let_fields: Vec<P<ast::Stmt>> = fields.iter()
.map(|&(name, span)| {
quote_stmt!(cx, let mut $name = None)
})
.collect();
// Declare key arms.
let key_arms: Vec<ast::Arm> = fields.iter()
.map(|&(name, span)| {
let s = cx.expr_str(span, token::get_ident(name));
quote_arm!(cx,
$s => {
$name = Some(
try!(::serde2::de::Deserialize::deserialize($deserializer))
);
continue;
})
})
.collect();
let extract_fields: Vec<P<ast::Stmt>> = fields.iter()
.map(|&(name, span)| {
let name_str = cx.expr_str(span, token::get_ident(name));
quote_stmt!(cx,
let $name = match $name {
Some($name) => $name,
None => try!($deserializer.missing_field($name_str)),
};
)
})
.collect();
let result = cx.expr_struct_ident(
span,
type_ident,
fields.iter()
.map(|&(name, span)| {
cx.field_imm(span, name, cx.expr_ident(span, name))
})
.collect()
);
quote_expr!(cx, {
$let_fields
loop {
let token = match try!($deserializer.expect_token()) {
::serde2::de::End => { break; }
token => token,
};
{
let key = match token {
::serde2::de::Str(s) => s,
::serde2::de::String(ref s) => &s,
token => {
let expected_tokens = [
::serde2::de::StrKind,
::serde2::de::StringKind,
];
return Err($deserializer.syntax_error(token, expected_tokens));
}
};
match key {
$key_arms
_ => { }
}
}
try!($deserializer.ignore_field(token))
}
$extract_fields
$extract_values
Ok($result)
})
}
@ -729,89 +690,144 @@ fn deserialize_enum(
span: Span,
type_ident: Ident,
fields: &[(Ident, Span, StaticFields)],
deserializer: P<ast::Expr>,
token: P<ast::Expr>
state: P<ast::Expr>,
) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident));
let variants = fields.iter()
.map(|&(name, span, _)| {
cx.expr_str(span, token::get_ident(name))
})
.collect();
let variants = cx.expr_vec(span, variants);
let arms: Vec<ast::Arm> = fields.iter()
.enumerate()
.map(|(i, &(name, span, ref parts))| {
let call = deserialize_static_fields(
// Match arms to extract a variant from a string
let variant_arms: Vec<ast::Arm> = fields.iter()
.map(|&(name, span, ref fields)| {
let value = deserialize_enum_variant(
cx,
span,
type_ident,
name,
parts,
|cx, span, _| {
quote_expr!(cx, try!($deserializer.expect_enum_elt()))
}
fields,
cx.expr_ident(span, cx.ident_of("visitor")),
);
quote_arm!(cx, $i => $call,)
let s = cx.expr_str(span, token::get_ident(name));
quote_arm!(cx, $s => $value,)
})
.collect();
quote_expr!(cx, {
let i = try!($deserializer.expect_enum_start($token, $type_name, $variants));
struct __Visitor;
let result = match i {
$arms
_ => { unreachable!() }
};
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
try!($deserializer.expect_enum_end());
fn visit_enum<
__V: ::serde2::de::EnumVisitor,
>(&mut self, name: &str, variant: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name {
self.visit_variant(variant, visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
Ok(result)
fn visit_variant<
__V: ::serde2::de::EnumVisitor,
>(&mut self, name: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> {
match name {
$variant_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
}
}
$state.visit(&mut __Visitor)
})
}
/// Create a deserializer for a single enum variant/struct:
/// - `outer_pat_ident` is the name of this enum variant/struct
/// - `getarg` should retrieve the `u32`-th field with name `&str`.
fn deserialize_static_fields(
fn deserialize_enum_variant(
cx: &ExtCtxt,
span: Span,
outer_pat_ident: Ident,
type_ident: Ident,
variant_ident: Ident,
fields: &StaticFields,
getarg: |&ExtCtxt, Span, token::InternedString| -> P<Expr>
) -> P<Expr> {
state: P<ast::Expr>,
) -> P<ast::Expr> {
let variant_path = cx.path(span, vec![type_ident, variant_ident]);
match *fields {
Unnamed(ref fields) => {
if fields.is_empty() {
cx.expr_ident(span, outer_pat_ident)
} else {
let fields = fields.iter().enumerate().map(|(i, &span)| {
getarg(
cx,
span,
token::intern_and_get_ident(&format!("_field{}", i))
)
}).collect();
let result = cx.expr_path(variant_path);
cx.expr_call_ident(span, outer_pat_ident, fields)
quote_expr!(cx, {
try!($state.visit_unit());
Ok($result)
})
} else {
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
let visit_seq_expr = declare_visit_seq(
cx,
span,
variant_path,
&field_names[],
);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::EnumSeqVisitor for __Visitor {
type Value = $type_ident;
fn visit<
V: ::serde2::de::SeqVisitor,
>(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> {
$visit_seq_expr
}
}
$state.visit_seq(&mut __Visitor)
})
}
}
Named(ref fields) => {
// use the field's span to get nicer error messages.
let fields = fields.iter().map(|&(name, span)| {
let arg = getarg(
cx,
span,
token::get_ident(name)
);
cx.field_imm(span, name, arg)
}).collect();
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
cx.expr_struct_ident(span, outer_pat_ident, fields)
let field_deserializer = declare_map_field_deserializer(
cx,
span,
&field_names[],
fields,
);
let visit_map_expr = declare_visit_map(
cx,
span,
variant_path,
&field_names[],
fields,
);
quote_expr!(cx, {
$field_deserializer
struct __Visitor;
impl ::serde2::de::EnumMapVisitor for __Visitor {
type Value = $type_ident;
fn visit<
V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> {
$visit_map_expr
}
}
$state.visit_map(&mut __Visitor)
})
}
}
}
*/

View File

@ -201,6 +201,13 @@ pub trait Visitor {
>(&mut self, _name: &str, _variant: &str, _visitor: V) -> Result<Self::Value, V::Error> {
Err(Error::syntax_error())
}
#[inline]
fn visit_variant<
V: EnumVisitor,
>(&mut self, _name: &str, _visitor: V) -> Result<Self::Value, V::Error> {
Err(Error::syntax_error())
}
}
pub trait SeqVisitor {
@ -1177,11 +1184,17 @@ mod tests {
fn visit_enum<
V: super::EnumVisitor,
>(&mut self, name: &str, variant: &str, mut visitor: V) -> Result<Enum, V::Error> {
if name != "Enum" {
return Err(super::Error::syntax_error());
if name == "Enum" {
self.visit_variant(variant, visitor)
} else {
Err(super::Error::syntax_error());
}
}
match variant {
fn visit_variant<
V: super::EnumVisitor,
>(&mut self, name: &str, mut visitor: V) -> Result<Enum, V::Error> {
match name {
"Unit" => {
try!(visitor.visit_unit());
Ok(Enum::Unit)