From bfac1a581c8a4f4a2c7d3b6622114f395e98cd96 Mon Sep 17 00:00:00 2001 From: John Heitmann Date: Sun, 10 Jan 2016 19:34:48 -0800 Subject: [PATCH] Implemented disallow_unknown * Added codegen for disallow_unknown * ... with new default to ignore unknown values during deserialization * Added ContainerAttrs --- README.md | 9 +++ serde/src/de/impls.rs | 104 ++++++++++++++++++++++++++ serde_codegen/src/attr.rs | 63 ++++++++++++++++ serde_codegen/src/de.rs | 62 +++++++++++++-- serde_codegen/src/field.rs | 9 ++- serde_tests/tests/test_annotations.rs | 68 ++++++++++++++++- serde_tests/tests/test_de.rs | 6 +- serde_tests/tests/token.rs | 55 +++++++++++++- 8 files changed, 366 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 6cd422cd..b78781ad 100644 --- a/README.md +++ b/README.md @@ -585,6 +585,8 @@ Annotations `serde_codegen` and `serde_macros` support annotations that help to customize how types are serialized. Here are the supported annotations: +Field Annotations: + | Annotation | Function | | ---------- | -------- | | `#[serde(rename(json="name1", xml="name2"))` | Serialize this field with the given name for the given formats | @@ -594,6 +596,13 @@ how types are serialized. Here are the supported annotations: | `#[serde(skip_serializing_if_empty)` | Do not serialize this value if `$value.is_empty()` is `true` | | `#[serde(skip_serializing_if_none)` | Do not serialize this value if `$value.is_none()` is `true` | +Structure Annotations: + +| Annotation | Function | +| ---------- | -------- | +| `#[serde(disallow_unknown)` | Always error during serialization when encountering unknown fields. When absent, unknown fields are ignored for self-describing formats like JSON. | + + Serialization Formats Using Serde ================================= diff --git a/serde/src/de/impls.rs b/serde/src/de/impls.rs index 68349ba3..59925707 100644 --- a/serde/src/de/impls.rs +++ b/serde/src/de/impls.rs @@ -896,6 +896,110 @@ impl Deserialize for Result where T: Deserialize, E: Deserialize { } } +/////////////////////////////////////////////////////////////////////////////// + +/// A target for deserializers that want to ignore data. Implements +/// Deserialize and silently eats data given to it. +pub struct IgnoredAny; + +impl Deserialize for IgnoredAny { + #[inline] + fn deserialize(deserializer: &mut D) -> Result + where D: Deserializer, + { + struct IgnoredAnyVisitor; + + impl Visitor for IgnoredAnyVisitor { + type Value = IgnoredAny; + + #[inline] + fn visit_bool(&mut self, _: bool) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_i64(&mut self, _: i64) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_u64(&mut self, _: u64) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_f64(&mut self, _: f64) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_str(&mut self, _: &str) -> Result + where E: Error, + { + Ok(IgnoredAny) + } + + #[inline] + fn visit_none(&mut self) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_some(&mut self, _: &mut D) -> Result + where D: Deserializer, + { + Ok(IgnoredAny) + } + + #[inline] + fn visit_newtype_struct(&mut self, _: &mut D) -> Result + where D: Deserializer, + { + Ok(IgnoredAny) + } + + #[inline] + fn visit_unit(&mut self) -> Result { + Ok(IgnoredAny) + } + + #[inline] + fn visit_seq(&mut self, mut visitor: V) -> Result + where V: SeqVisitor, + { + while let Some(_) = try!(visitor.visit::()) { + // Gobble + } + + try!(visitor.end()); + Ok(IgnoredAny) + } + + #[inline] + fn visit_map(&mut self, mut visitor: V) -> Result + where V: MapVisitor, + { + while let Some((_, _)) = try!(visitor.visit::()) { + // Gobble + } + + try!(visitor.end()); + Ok(IgnoredAny) + } + + #[inline] + fn visit_bytes(&mut self, _: &[u8]) -> Result + where E: Error, + { + Ok(IgnoredAny) + } + } + + deserializer.deserialize(IgnoredAnyVisitor) + } +} + + /////////////////////////////////////////////////////////////////////////////// #[cfg(feature = "num-bigint")] diff --git a/serde_codegen/src/attr.rs b/serde_codegen/src/attr.rs index af65f6d1..3932dbb2 100644 --- a/serde_codegen/src/attr.rs +++ b/serde_codegen/src/attr.rs @@ -243,3 +243,66 @@ impl<'a> FieldAttrsBuilder<'a> { } } } + +/// Represents container (e.g. struct) attribute information +#[derive(Debug)] +pub struct ContainerAttrs { + disallow_unknown: bool, +} + +impl ContainerAttrs { + pub fn disallow_unknown(&self) -> bool { + self.disallow_unknown + } +} + +pub struct ContainerAttrsBuilder { + disallow_unknown: bool, +} + +impl ContainerAttrsBuilder { + pub fn new() -> ContainerAttrsBuilder { + ContainerAttrsBuilder { + disallow_unknown: false, + } + } + + pub fn attrs(self, attrs: &[ast::Attribute]) -> ContainerAttrsBuilder { + attrs.iter().fold(self, ContainerAttrsBuilder::attr) + } + + pub fn attr(self, attr: &ast::Attribute) -> ContainerAttrsBuilder { + match attr.node.value.node { + ast::MetaList(ref name, ref items) if name == &"serde" => { + attr::mark_used(&attr); + items.iter().fold(self, ContainerAttrsBuilder::meta_item) + } + _ => { + self + } + } + } + + pub fn meta_item(self, meta_item: &P) -> ContainerAttrsBuilder { + match meta_item.node { + ast::MetaWord(ref name) if name == &"disallow_unknown" => { + self.disallow_unknown() + } + _ => { + // Ignore unknown meta variables for now. + self + } + } + } + + pub fn disallow_unknown(mut self) -> ContainerAttrsBuilder { + self.disallow_unknown = true; + self + } + + pub fn build(self) -> ContainerAttrs { + ContainerAttrs { + disallow_unknown: self.disallow_unknown, + } + } +} diff --git a/serde_codegen/src/de.rs b/serde_codegen/src/de.rs index b257091a..82bc567d 100644 --- a/serde_codegen/src/de.rs +++ b/serde_codegen/src/de.rs @@ -14,7 +14,7 @@ use syntax::ext::base::{Annotatable, ExtCtxt}; use syntax::ext::build::AstBuilder; use syntax::ptr::P; -use attr; +use attr::{self, ContainerAttrs}; use field; pub fn expand_derive_deserialize( @@ -82,6 +82,8 @@ fn deserialize_body( impl_generics: &ast::Generics, ty: P, ) -> P { + let container_attrs = field::container_attrs(cx, item); + match item.node { ast::ItemStruct(ref variant_data, _) => { deserialize_item_struct( @@ -91,6 +93,7 @@ fn deserialize_body( impl_generics, ty, variant_data, + &container_attrs, ) } ast::ItemEnum(ref enum_def, _) => { @@ -101,6 +104,7 @@ fn deserialize_body( impl_generics, ty, enum_def, + &container_attrs, ) } _ => cx.bug("expected ItemStruct or ItemEnum in #[derive(Deserialize)]") @@ -114,6 +118,7 @@ fn deserialize_item_struct( impl_generics: &ast::Generics, ty: P, variant_data: &ast::VariantData, + container_attrs: &ContainerAttrs, ) -> P { match *variant_data { ast::VariantData::Unit(_) => { @@ -158,6 +163,7 @@ fn deserialize_item_struct( impl_generics, ty, fields, + container_attrs, ) } } @@ -461,6 +467,7 @@ fn deserialize_struct( impl_generics: &ast::Generics, ty: P, fields: &[ast::StructField], + container_attrs: &ContainerAttrs, ) -> P { let where_clause = &impl_generics.where_clause; @@ -485,6 +492,7 @@ fn deserialize_struct( builder, type_path.clone(), fields, + container_attrs ); let type_name = builder.expr().str(type_ident); @@ -525,6 +533,7 @@ fn deserialize_item_enum( impl_generics: &ast::Generics, ty: P, enum_def: &EnumDef, + container_attrs: &ContainerAttrs ) -> P { let where_clause = &impl_generics.where_clause; @@ -541,7 +550,8 @@ fn deserialize_item_enum( .default() .build() }) - .collect() + .collect(), + container_attrs, ); let variants_expr = builder.expr().addr_of().slice() @@ -557,6 +567,12 @@ fn deserialize_item_enum( const VARIANTS: &'static [&'static str] = $variants_expr; ).unwrap(); + let ignored_arm = if !container_attrs.disallow_unknown() { + Some(quote_arm!(cx, __Field::__ignore => { Err(::serde::de::Error::end_of_stream()) })) + } else { + None + }; + // Match arms to extract a variant from a string let variant_arms: Vec<_> = enum_def.variants.iter() .enumerate() @@ -572,10 +588,12 @@ fn deserialize_item_enum( impl_generics, ty.clone(), variant, + container_attrs, ); quote_arm!(cx, $variant_name => { $expr }) }) + .chain(ignored_arm.into_iter()) .collect(); let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = @@ -616,6 +634,7 @@ fn deserialize_variant( generics: &ast::Generics, ty: P, variant: &ast::Variant, + container_attrs: &ContainerAttrs, ) -> P { let variant_ident = variant.node.name; @@ -652,6 +671,7 @@ fn deserialize_variant( generics, ty, fields, + container_attrs, ) } } @@ -708,6 +728,7 @@ fn deserialize_struct_variant( generics: &ast::Generics, ty: P, fields: &[ast::StructField], + container_attrs: &ContainerAttrs, ) -> P { let where_clause = &generics.where_clause; @@ -728,6 +749,7 @@ fn deserialize_struct_variant( builder, type_path, fields, + container_attrs, ); let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = @@ -771,12 +793,20 @@ fn deserialize_field_visitor( cx: &ExtCtxt, builder: &aster::AstBuilder, field_attrs: Vec, + container_attrs: &ContainerAttrs, ) -> Vec> { // Create the field names for the fields. let field_idents: Vec = (0 .. field_attrs.len()) .map(|i| builder.id(format!("__field{}", i))) .collect(); + let ignore_variant = if !container_attrs.disallow_unknown() { + let skip_ident = builder.id("__ignore"); + Some(builder.variant(skip_ident).unit()) + } else { + None + }; + let field_enum = builder.item() .attr().allow(&["non_camel_case_types"]) .enum_("__Field") @@ -785,6 +815,7 @@ fn deserialize_field_visitor( builder.variant(field_ident).unit() }) ) + .with_variants(ignore_variant.into_iter()) .build(); let index_field_arms: Vec<_> = field_idents.iter() @@ -817,12 +848,18 @@ fn deserialize_field_visitor( }) .collect(); + let fallthrough_arm_expr = if !container_attrs.disallow_unknown() { + quote_expr!(cx, Ok(__Field::__ignore)) + } else { + quote_expr!(cx, Err(::serde::de::Error::unknown_field(value))) + }; + let str_body = if formats.is_empty() { // No formats specific attributes, so no match on format required quote_expr!(cx, match value { $default_field_arms - _ => { Err(::serde::de::Error::unknown_field(value)) } + _ => { $fallthrough_arm_expr } }) } else { let field_arms: Vec<_> = formats.iter() @@ -844,7 +881,7 @@ fn deserialize_field_visitor( match value { $arms _ => { - Err(::serde::de::Error::unknown_field(value)) + $fallthrough_arm_expr } }}) }) @@ -855,7 +892,7 @@ fn deserialize_field_visitor( $fmt_matches _ => match value { $default_field_arms - _ => { Err(::serde::de::Error::unknown_field(value)) } + _ => $fallthrough_arm_expr } } ) @@ -920,11 +957,13 @@ fn deserialize_struct_visitor( builder: &aster::AstBuilder, struct_path: ast::Path, fields: &[ast::StructField], + container_attrs: &ContainerAttrs, ) -> (Vec>, P, P) { let field_visitor = deserialize_field_visitor( cx, builder, field::struct_field_attrs(cx, builder, fields), + container_attrs ); let visit_map_expr = deserialize_map( @@ -932,6 +971,7 @@ fn deserialize_struct_visitor( builder, struct_path, fields, + container_attrs, ); let fields_expr = builder.expr().addr_of().slice() @@ -958,6 +998,7 @@ fn deserialize_map( builder: &aster::AstBuilder, struct_path: ast::Path, fields: &[ast::StructField], + container_attrs: &ContainerAttrs, ) -> P { // Create the field names for the fields. let field_names: Vec = (0 .. fields.len()) @@ -969,6 +1010,16 @@ fn deserialize_map( .map(|field_name| quote_stmt!(cx, let mut $field_name = None;).unwrap()) .collect(); + + // Visit ignored values to consume them + let ignored_arm = if !container_attrs.disallow_unknown() { + Some(quote_arm!(cx, + _ => { try!(visitor.visit_value::<::serde::de::impls::IgnoredAny>()); } + )) + } else { + None + }; + // Match arms to extract a value for a field. let value_arms: Vec = field_names.iter() .map(|field_name| { @@ -978,6 +1029,7 @@ fn deserialize_map( } ) }) + .chain(ignored_arm.into_iter()) .collect(); let extract_values: Vec> = field_names.iter() diff --git a/serde_codegen/src/field.rs b/serde_codegen/src/field.rs index bf1a522b..a0a992f5 100644 --- a/serde_codegen/src/field.rs +++ b/serde_codegen/src/field.rs @@ -2,7 +2,7 @@ use syntax::ast; use syntax::ext::base::ExtCtxt; use aster; -use attr::{FieldAttrs, FieldAttrsBuilder}; +use attr::{ContainerAttrs, ContainerAttrsBuilder, FieldAttrs, FieldAttrsBuilder}; pub fn struct_field_attrs( _cx: &ExtCtxt, @@ -15,3 +15,10 @@ pub fn struct_field_attrs( }) .collect() } + +pub fn container_attrs( + _cx: &ExtCtxt, + container: &ast::Item, +) -> ContainerAttrs { + ContainerAttrsBuilder::new().attrs(container.attrs()).build() +} \ No newline at end of file diff --git a/serde_tests/tests/test_annotations.rs b/serde_tests/tests/test_annotations.rs index 87e87960..f53f08e1 100644 --- a/serde_tests/tests/test_annotations.rs +++ b/serde_tests/tests/test_annotations.rs @@ -1,6 +1,13 @@ use std::default; -use token::{Token, assert_tokens, assert_ser_tokens, assert_de_tokens}; +use token::{ + Error, + Token, + assert_tokens, + assert_ser_tokens, + assert_de_tokens, + assert_de_tokens_error +}; #[derive(Debug, PartialEq, Serialize, Deserialize)] struct Default { @@ -9,6 +16,12 @@ struct Default { a2: i32, } +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(disallow_unknown)] +struct DisallowUnknown { + a1: i32, +} + #[derive(Debug, PartialEq, Serialize, Deserialize)] struct Rename { a1: i32, @@ -86,6 +99,59 @@ fn test_default() { ); } +#[test] +fn test_ignore_unknown() { + // 'Default' allows unknown. Basic smoke test of ignore... + assert_de_tokens( + &Default { a1: 1, a2: 2}, + vec![ + Token::StructStart("Default", Some(5)), + + Token::MapSep, + Token::Str("whoops1"), + Token::I32(2), + + Token::MapSep, + Token::Str("a1"), + Token::I32(1), + + Token::MapSep, + Token::Str("whoops2"), + Token::SeqStart(Some(1)), + Token::SeqSep, + Token::I32(2), + Token::SeqEnd, + + Token::MapSep, + Token::Str("a2"), + Token::I32(2), + + Token::MapSep, + Token::Str("whoops3"), + Token::I32(2), + + Token::MapEnd, + ] + ); + + assert_de_tokens_error::( + vec![ + Token::StructStart("DisallowUnknown", Some(2)), + + Token::MapSep, + Token::Str("a1"), + Token::I32(1), + + Token::MapSep, + Token::Str("whoops"), + Token::I32(2), + + Token::MapEnd, + ], + Error::UnknownFieldError("whoops".to_owned()) + ); +} + #[test] fn test_rename() { assert_tokens( diff --git a/serde_tests/tests/test_de.rs b/serde_tests/tests/test_de.rs index 5f847770..9e0b2f83 100644 --- a/serde_tests/tests/test_de.rs +++ b/serde_tests/tests/test_de.rs @@ -7,7 +7,7 @@ use num::rational::Ratio; use serde::de::{Deserializer, Visitor}; -use token::{Token, assert_de_tokens}; +use token::{Error, Token, assert_de_tokens, assert_de_tokens_ignore}; ////////////////////////////////////////////////////////////////////////// @@ -39,7 +39,11 @@ macro_rules! declare_test { #[test] fn $name() { $( + // Test ser/de roundtripping assert_de_tokens(&$value, $tokens); + + // Test that the tokens are ignorable + assert_de_tokens_ignore($tokens); )+ } } diff --git a/serde_tests/tests/token.rs b/serde_tests/tests/token.rs index fa4468b0..c9e6ef78 100644 --- a/serde_tests/tests/token.rs +++ b/serde_tests/tests/token.rs @@ -310,7 +310,7 @@ impl<'a, I> ser::Serializer for Serializer ////////////////////////////////////////////////////////////////////////////// #[derive(Clone, PartialEq, Debug)] -enum Error { +pub enum Error { SyntaxError, EndOfStreamError, UnknownFieldError(String), @@ -644,7 +644,7 @@ impl<'a, I> de::MapVisitor for DeserializerMapVisitor<'a, I> match self.de.tokens.peek() { Some(&Token::MapSep) => { self.de.tokens.next(); - self.len = self.len.map(|len| len - 1); + self.len = self.len.map(|len| if len > 0 { len - 1} else { 0 }); Ok(Some(try!(de::Deserialize::deserialize(self.de)))) } Some(&Token::MapEnd) => Ok(None), @@ -799,6 +799,57 @@ pub fn assert_de_tokens(value: &T, tokens: Vec>) assert_eq!(de.tokens.next(), None); } +// Expect an error deserializing tokens into a T +pub fn assert_de_tokens_error(tokens: Vec>, error: Error) + where T: de::Deserialize + PartialEq + fmt::Debug, +{ + let mut de = Deserializer::new(tokens.into_iter()); + let v: Result = de::Deserialize::deserialize(&mut de); + assert_eq!(v.as_ref(), Err(&error)); +} + +// Tests that the given token stream is ignorable when embedded in +// an otherwise normal struct +pub fn assert_de_tokens_ignore(ignorable_tokens: Vec>) { + #[derive(PartialEq, Debug, Deserialize)] + struct IgnoreBase { + a: i32, + } + + let expected = IgnoreBase{a: 1}; + + // Embed the tokens to be ignored in the normal token + // stream for an IgnoreBase type + let concated_tokens : Vec> = vec![ + Token::MapStart(Some(2)), + Token::MapSep, + Token::Str("a"), + Token::I32(1), + + Token::MapSep, + Token::Str("ignored") + ] + .into_iter() + .chain(ignorable_tokens.into_iter()) + .chain(vec![ + Token::MapEnd, + ].into_iter()) + .collect(); + + let mut de = Deserializer::new(concated_tokens.into_iter()); + let v: Result = de::Deserialize::deserialize(&mut de); + + // We run this test on every token stream for convenience, but + // some token streams don't make sense embedded as a map value, + // so we ignore those. SyntaxError is the real sign of trouble. + if let Err(Error::UnexpectedToken(_)) = v { + return; + } + + assert_eq!(v.as_ref(), Ok(&expected)); + assert_eq!(de.tokens.next(), None); +} + pub fn assert_tokens(value: &T, tokens: Vec>) where T: ser::Serialize + de::Deserialize + PartialEq + fmt::Debug, {