From 886670134a5382bc181cd44b93592712e7c01076 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Wed, 13 Apr 2016 00:34:29 -0700 Subject: [PATCH] feat(codegen): Infer Default and Deserialize bounds correctly --- serde_codegen/src/attr.rs | 19 ++-- serde_codegen/src/bound.rs | 135 ++++++++++++++++++++++++++ serde_codegen/src/de.rs | 94 +++++++++++++++--- serde_codegen/src/lib.rs.in | 1 + serde_codegen/src/ser.rs | 129 +----------------------- serde_tests/tests/test_annotations.rs | 96 +++++++++++++++++- 6 files changed, 325 insertions(+), 149 deletions(-) create mode 100644 serde_codegen/src/bound.rs diff --git a/serde_codegen/src/attr.rs b/serde_codegen/src/attr.rs index 6caf91a2..10f2068c 100644 --- a/serde_codegen/src/attr.rs +++ b/serde_codegen/src/attr.rs @@ -180,7 +180,7 @@ pub struct FieldAttrs { skip_serializing_field_if: Option>, default_expr_if_missing: Option>, serialize_with: Option>, - deserialize_with: Option>, + deserialize_with: P, } impl FieldAttrs { @@ -197,6 +197,8 @@ impl FieldAttrs { None => { cx.span_bug(field.span, "struct field has no name?") } }; + let identity = quote_expr!(cx, |x| x); + let mut field_attrs = FieldAttrs { name: Name::new(field_ident), skip_serializing_field: false, @@ -204,7 +206,7 @@ impl FieldAttrs { skip_serializing_field_if: None, default_expr_if_missing: None, serialize_with: None, - deserialize_with: None, + deserialize_with: identity, }; for meta_items in field.attrs.iter().filter_map(get_serde_meta_items) { @@ -292,7 +294,7 @@ impl FieldAttrs { try!(parse_lit_into_path(cx, name, lit)), ); - field_attrs.deserialize_with = Some(expr); + field_attrs.deserialize_with = expr; } _ => { @@ -346,8 +348,8 @@ impl FieldAttrs { self.serialize_with.as_ref() } - pub fn deserialize_with(&self) -> Option<&P> { - self.deserialize_with.as_ref() + pub fn deserialize_with(&self) -> &P { + &self.deserialize_with } } @@ -626,7 +628,7 @@ fn wrap_deserialize_with(cx: &ExtCtxt, let where_clause = &generics.where_clause; - quote_expr!(cx, { + quote_expr!(cx, ({ struct __SerdeDeserializeWithStruct $generics $where_clause { value: $field_ty, } @@ -640,7 +642,6 @@ fn wrap_deserialize_with(cx: &ExtCtxt, } } - let value: $ty_path = try!(visitor.visit_value()); - Ok(value.value) - }) + |visit: $ty_path| visit.value + })) } diff --git a/serde_codegen/src/bound.rs b/serde_codegen/src/bound.rs new file mode 100644 index 00000000..db7c5a79 --- /dev/null +++ b/serde_codegen/src/bound.rs @@ -0,0 +1,135 @@ +use std::collections::HashSet; + +use aster::AstBuilder; + +use syntax::ast; +use syntax::ext::base::ExtCtxt; +use syntax::ptr::P; +use syntax::visit; + +pub fn with_bound( + cx: &ExtCtxt, + builder: &AstBuilder, + item: &ast::Item, + generics: &ast::Generics, + filter: &Fn(&ast::StructField) -> bool, + bound: &[&'static str], +) -> ast::Generics { + let path = builder.path().global().ids(bound).build(); + + builder.from_generics(generics.clone()) + .with_predicates( + all_variants(cx, item).iter() + .flat_map(|variant_data| all_struct_fields(variant_data)) + .filter(|field| filter(field)) + .map(|field| &field.ty) + // TODO this filter can be removed later, see comment on function + .filter(|ty| contains_generic(ty, generics)) + .map(|ty| strip_reference(ty)) + .map(|ty| builder.where_predicate() + // the type that is being bounded e.g. T + .bound().build(ty.clone()) + // the bound e.g. Serialize + .bound().trait_(path.clone()).build() + .build())) + .build() +} + +fn all_variants<'a>(cx: &ExtCtxt, item: &'a ast::Item) -> Vec<&'a ast::VariantData> { + match item.node { + ast::ItemKind::Struct(ref variant_data, _) => { + vec![variant_data] + } + ast::ItemKind::Enum(ref enum_def, _) => { + enum_def.variants.iter() + .map(|variant| &variant.node.data) + .collect() + } + _ => { + cx.span_bug(item.span, "expected Item to be Struct or Enum"); + } + } +} + +fn all_struct_fields(variant_data: &ast::VariantData) -> &[ast::StructField] { + match *variant_data { + ast::VariantData::Struct(ref fields, _) | + ast::VariantData::Tuple(ref fields, _) => { + fields + } + ast::VariantData::Unit(_) => { + &[] + } + } +} + +// Rust <1.7 enforces that `where` clauses involve generic type parameters. The +// corresponding compiler error is E0193. It is no longer enforced in Rust >=1.7 +// so this filtering can be removed in the future when we stop supporting <1.7. +// +// E0193 means we must not generate a `where` clause like `i32: Serialize` +// because even though i32 implements Serialize, i32 is not a generic type +// parameter. Clauses like `T: Serialize` and `Option: Serialize` are okay. +// This function decides whether a given type references any of the generic type +// parameters in the input `Generics`. +fn contains_generic(ty: &ast::Ty, generics: &ast::Generics) -> bool { + struct FindGeneric<'a> { + generic_names: &'a HashSet, + found_generic: bool, + } + impl<'a, 'v> visit::Visitor<'v> for FindGeneric<'a> { + fn visit_path(&mut self, path: &'v ast::Path, _id: ast::NodeId) { + if !path.global + && path.segments.len() == 1 + && self.generic_names.contains(&path.segments[0].identifier.name) { + self.found_generic = true; + } else { + visit::walk_path(self, path); + } + } + } + + let generic_names: HashSet<_> = generics.ty_params.iter() + .map(|ty_param| ty_param.ident.name) + .collect(); + + let mut visitor = FindGeneric { + generic_names: &generic_names, + found_generic: false, + }; + visit::walk_ty(&mut visitor, ty); + visitor.found_generic +} + +// This is required to handle types that use both a reference and a value of +// the same type, as in: +// +// enum Test<'a, T> where T: 'a { +// Lifetime(&'a T), +// NoLifetime(T), +// } +// +// Preserving references, we would generate an impl like: +// +// impl<'a, T> Serialize for Test<'a, T> +// where &'a T: Serialize, +// T: Serialize { ... } +// +// And taking a reference to one of the elements would fail with: +// +// error: cannot infer an appropriate lifetime for pattern due +// to conflicting requirements [E0495] +// Test::NoLifetime(ref v) => { ... } +// ^~~~~ +// +// Instead, we strip references before adding `T: Serialize` bounds in order to +// generate: +// +// impl<'a, T> Serialize for Test<'a, T> +// where T: Serialize { ... } +fn strip_reference(ty: &P) -> &P { + match ty.node { + ast::TyKind::Rptr(_, ref mut_ty) => &mut_ty.ty, + _ => ty + } +} diff --git a/serde_codegen/src/de.rs b/serde_codegen/src/de.rs index 4b7b4623..dd649841 100644 --- a/serde_codegen/src/de.rs +++ b/serde_codegen/src/de.rs @@ -14,6 +14,7 @@ use syntax::parse::token::InternedString; use syntax::ptr::P; use attr; +use bound; use error::Error; pub fn expand_derive_deserialize( @@ -46,11 +47,7 @@ pub fn expand_derive_deserialize( } }; - let impl_generics = builder.from_generics(generics.clone()) - .add_ty_param_bound( - builder.path().global().ids(&["serde", "de", "Deserialize"]).build() - ) - .build(); + let impl_generics = build_impl_generics(cx, &builder, item, generics); let ty = builder.ty().path() .segment(item.ident).with_generics(impl_generics.clone()).build() @@ -79,6 +76,82 @@ pub fn expand_derive_deserialize( push(Annotatable::Item(impl_item)) } +// All the generics in the input, plus a bound `T: Deserialize` for each generic +// field type that will be deserialized by us, plus a bound `T: Default` for +// each generic field type that will be set to a default value. +fn build_impl_generics( + cx: &ExtCtxt, + builder: &aster::AstBuilder, + item: &Item, + generics: &ast::Generics, +) -> ast::Generics { + let generics = bound::with_bound(cx, builder, item, generics, + &deserialized_by_us, + &["serde", "de", "Deserialize"]); + let generics = bound::with_bound(cx, builder, item, &generics, + &requires_default, + &["std", "default", "Default"]); + generics +} + +// Fields with a `skip_deserializing` or `deserialize_with` attribute are not +// deserialized by us. All other fields may need a `T: Deserialize` bound where +// T is the type of the field. +fn deserialized_by_us(field: &ast::StructField) -> bool { + for meta_items in field.attrs.iter().filter_map(attr::get_serde_meta_items) { + for meta_item in meta_items { + match meta_item.node { + ast::MetaItemKind::Word(ref name) if name == &"skip_deserializing" => { + return false + } + ast::MetaItemKind::NameValue(ref name, _) if name == &"deserialize_with" => { + // TODO: For now we require `T: Deserialize` even if the + // field has `deserialize_with`. The reason is the signature + // of serde::de::MapVisitor::missing_field which looks like: + // + // fn missing_field(...) -> Result where T: Deserialize + // + // So in order to use missing_field, the type must have the + // `T: Deserialize` bound. Some formats rely on this bound + // because they treat missing fields as unit. + // + // Long-term the fix would be to change the signature of + // missing_field so it can, for example, use the + // `deserialize_with` function to visit a unit in place of + // the missing field. + // + // See https://github.com/serde-rs/serde/issues/259 + } + _ => {} + } + } + } + true +} + +// Fields with a `default` attribute (not `default=...`), and fields with a +// `skip_deserializing` attribute that do not also have `default=...`. +fn requires_default(field: &ast::StructField) -> bool { + let mut has_skip_deserializing = false; + for meta_items in field.attrs.iter().filter_map(attr::get_serde_meta_items) { + for meta_item in meta_items { + match meta_item.node { + ast::MetaItemKind::Word(ref name) if name == &"default" => { + return true + } + ast::MetaItemKind::NameValue(ref name, _) if name == &"default" => { + return false + } + ast::MetaItemKind::Word(ref name) if name == &"skip_deserializing" => { + has_skip_deserializing = true + } + _ => {} + } + } + } + has_skip_deserializing +} + fn deserialize_body( cx: &ExtCtxt, builder: &aster::AstBuilder, @@ -442,9 +515,10 @@ fn deserialize_struct_as_seq( let $name = $default; ).unwrap() } else { + let deserialize_with = attrs.deserialize_with(); quote_stmt!(cx, let $name = match try!(visitor.visit()) { - Some(value) => { value }, + Some(value) => { $deserialize_with(value) }, None => { return Err(::serde::de::Error::end_of_stream()); } @@ -1040,14 +1114,10 @@ fn deserialize_map( let value_arms = fields_attrs_names.iter() .filter(|&&(_, ref attrs, _)| !attrs.skip_deserializing_field()) .map(|&(_, ref attrs, name)| { - let expr = match attrs.deserialize_with() { - Some(expr) => expr.clone(), - None => quote_expr!(cx, visitor.visit_value()), - }; - + let deserialize_with = attrs.deserialize_with(); quote_arm!(cx, __Field::$name => { - $name = Some(try!($expr)); + $name = Some($deserialize_with(try!(visitor.visit_value()))); } ) }) diff --git a/serde_codegen/src/lib.rs.in b/serde_codegen/src/lib.rs.in index 9c7c34fb..66c25338 100644 --- a/serde_codegen/src/lib.rs.in +++ b/serde_codegen/src/lib.rs.in @@ -1,4 +1,5 @@ mod attr; +mod bound; mod de; mod error; mod ser; diff --git a/serde_codegen/src/ser.rs b/serde_codegen/src/ser.rs index 955a7165..ed79b542 100644 --- a/serde_codegen/src/ser.rs +++ b/serde_codegen/src/ser.rs @@ -1,5 +1,3 @@ -use std::collections::HashSet; - use aster; use syntax::ast::{ @@ -10,11 +8,10 @@ use syntax::ast::{ use syntax::ast; use syntax::codemap::Span; use syntax::ext::base::{Annotatable, ExtCtxt}; -use syntax::ext::build::AstBuilder; use syntax::ptr::P; -use syntax::visit; use attr; +use bound; use error::Error; pub fn expand_derive_serialize( @@ -96,56 +93,9 @@ fn build_impl_generics( item: &Item, generics: &ast::Generics, ) -> ast::Generics { - let serialize_path = builder.path() - .global() - .ids(&["serde", "ser", "Serialize"]) - .build(); - - builder.from_generics(generics.clone()) - .with_predicates( - all_variants(cx, item).iter() - .flat_map(|variant_data| all_struct_fields(variant_data)) - .filter(|field| serialized_by_us(field)) - .map(|field| &field.ty) - // TODO this filter can be removed later, see comment on function - .filter(|ty| contains_generic(ty, generics)) - .map(|ty| strip_reference(ty)) - .map(|ty| builder.where_predicate() - // the type that is being bounded i.e. T - .bound().build(ty.clone()) - // the bound i.e. Serialize - .bound().trait_(serialize_path.clone()).build() - .build())) - .build() -} - -fn all_variants<'a>(cx: &ExtCtxt, item: &'a Item) -> Vec<&'a ast::VariantData> { - match item.node { - ast::ItemKind::Struct(ref variant_data, _) => { - vec![variant_data] - } - ast::ItemKind::Enum(ref enum_def, _) => { - enum_def.variants.iter() - .map(|variant| &variant.node.data) - .collect() - } - _ => { - cx.span_bug(item.span, - "expected Item to be Struct or Enum in #[derive(Serialize)]"); - } - } -} - -fn all_struct_fields(variant_data: &ast::VariantData) -> &[ast::StructField] { - match *variant_data { - ast::VariantData::Struct(ref fields, _) | - ast::VariantData::Tuple(ref fields, _) => { - fields - } - ast::VariantData::Unit(_) => { - &[] - } - } + bound::with_bound(cx, builder, item, generics, + &serialized_by_us, + &["serde", "ser", "Serialize"]) } // Fields with a `skip_serializing` or `serialize_with` attribute are not @@ -168,77 +118,6 @@ fn serialized_by_us(field: &ast::StructField) -> bool { true } -// Rust <1.7 enforces that `where` clauses involve generic type parameters. The -// corresponding compiler error is E0193. It is no longer enforced in Rust >=1.7 -// so this filtering can be removed in the future when we stop supporting <1.7. -// -// E0193 means we must not generate a `where` clause like `i32: Serialize` -// because even though i32 implements Serialize, i32 is not a generic type -// parameter. Clauses like `T: Serialize` and `Option: Serialize` are okay. -// This function decides whether a given type references any of the generic type -// parameters in the input `Generics`. -fn contains_generic(ty: &ast::Ty, generics: &ast::Generics) -> bool { - struct FindGeneric<'a> { - generic_names: &'a HashSet, - found_generic: bool, - } - impl<'a, 'v> visit::Visitor<'v> for FindGeneric<'a> { - fn visit_path(&mut self, path: &'v ast::Path, _id: ast::NodeId) { - if !path.global - && path.segments.len() == 1 - && self.generic_names.contains(&path.segments[0].identifier.name) { - self.found_generic = true; - } else { - visit::walk_path(self, path); - } - } - } - - let generic_names: HashSet<_> = generics.ty_params.iter() - .map(|ty_param| ty_param.ident.name) - .collect(); - - let mut visitor = FindGeneric { - generic_names: &generic_names, - found_generic: false, - }; - visit::walk_ty(&mut visitor, ty); - visitor.found_generic -} - -// This is required to handle types that use both a reference and a value of -// the same type, as in: -// -// enum Test<'a, T> where T: 'a { -// Lifetime(&'a T), -// NoLifetime(T), -// } -// -// Preserving references, we would generate an impl like: -// -// impl<'a, T> Serialize for Test<'a, T> -// where &'a T: Serialize, -// T: Serialize { ... } -// -// And taking a reference to one of the elements would fail with: -// -// error: cannot infer an appropriate lifetime for pattern due -// to conflicting requirements [E0495] -// Test::NoLifetime(ref v) => { ... } -// ^~~~~ -// -// Instead, we strip references before adding `T: Serialize` bounds in order to -// generate: -// -// impl<'a, T> Serialize for Test<'a, T> -// where T: Serialize { ... } -fn strip_reference(ty: &P) -> &P { - match ty.node { - ast::TyKind::Rptr(_, ref mut_ty) => &mut_ty.ty, - _ => ty - } -} - fn serialize_body( cx: &ExtCtxt, builder: &aster::AstBuilder, diff --git a/serde_tests/tests/test_annotations.rs b/serde_tests/tests/test_annotations.rs index 5d25af74..865105e7 100644 --- a/serde_tests/tests/test_annotations.rs +++ b/serde_tests/tests/test_annotations.rs @@ -1,4 +1,3 @@ -use std::default::Default; use serde::{Serialize, Serializer, Deserialize, Deserializer}; use token::{ @@ -61,7 +60,7 @@ impl DeserializeWith for i32 { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -struct DefaultStruct +struct DefaultStruct where C: MyDefault, E: MyDefault, { @@ -122,7 +121,7 @@ fn test_default_struct() { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -enum DefaultEnum +enum DefaultEnum where C: MyDefault, E: MyDefault { @@ -184,6 +183,97 @@ fn test_default_enum() { ); } +// Does not implement std::default::Default. +#[derive(Debug, PartialEq, Deserialize)] +struct NoStdDefault(i8); + +impl MyDefault for NoStdDefault { + fn my_default() -> Self { + NoStdDefault(123) + } +} + +#[derive(Debug, PartialEq, Deserialize)] +struct ContainsNoStdDefault { + #[serde(default="MyDefault::my_default")] + a: A, +} + +// Tests that a struct field does not need to implement std::default::Default if +// it is annotated with `default=...`. +#[test] +fn test_no_std_default() { + assert_de_tokens( + &ContainsNoStdDefault { a: NoStdDefault(123) }, + vec![ + Token::StructStart("ContainsNoStdDefault", Some(1)), + Token::StructEnd, + ] + ); + + assert_de_tokens( + &ContainsNoStdDefault { a: NoStdDefault(8) }, + vec![ + Token::StructStart("ContainsNoStdDefault", Some(1)), + + Token::StructSep, + Token::Str("a"), + Token::StructNewType("NoStdDefault"), + Token::I8(8), + + Token::StructEnd, + ] + ); +} + +// Does not implement Deserialize. +#[derive(Debug, PartialEq)] +struct NotDeserializeStruct(i8); + +impl Default for NotDeserializeStruct { + fn default() -> Self { + NotDeserializeStruct(123) + } +} + +// Does not implement Deserialize. +#[derive(Debug, PartialEq)] +enum NotDeserializeEnum { Trouble } + +impl MyDefault for NotDeserializeEnum { + fn my_default() -> Self { + NotDeserializeEnum::Trouble + } +} + +#[derive(Debug, PartialEq, Deserialize)] +struct ContainsNotDeserialize { + #[serde(skip_deserializing)] + a: A, + #[serde(skip_deserializing, default)] + b: B, + #[serde(skip_deserializing, default="MyDefault::my_default")] + c: C, +} + +// Tests that a struct field does not need to implement Deserialize if it is +// annotated with skip_deserializing, whether using the std Default or a +// custom default. +#[test] +fn test_elt_not_deserialize() { + assert_de_tokens( + &ContainsNotDeserialize { + a: NotDeserializeStruct(123), + b: NotDeserializeStruct(123), + c: NotDeserializeEnum::Trouble, + }, + vec![ + Token::StructStart("ContainsNotDeserialize", Some(3)), + Token::StructEnd, + ] + ); +} + #[derive(Debug, PartialEq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct DenyUnknown {