From 7a7a182ab6cb4b9ad4d7b0e332fb8ee010a6fc40 Mon Sep 17 00:00:00 2001 From: Mingun Date: Fri, 23 Oct 2020 16:12:06 +0500 Subject: [PATCH] Allow borrow for field identifiers --- serde/src/private/de.rs | 44 ++++++++--- serde_derive/src/de.rs | 127 +++++++++++++++++++------------- test_suite/tests/test_borrow.rs | 31 ++++++++ 3 files changed, 139 insertions(+), 63 deletions(-) diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 27052540..76d260bc 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -1,7 +1,7 @@ use lib::*; use de::{Deserialize, DeserializeSeed, Deserializer, Error, IntoDeserializer, Visitor}; -use de::value::BytesDeserializer; +use de::value::{BytesDeserializer, BorrowedBytesDeserializer}; #[cfg(any(feature = "std", feature = "alloc"))] use de::{MapAccess, Unexpected}; @@ -2527,20 +2527,26 @@ mod content { //////////////////////////////////////////////////////////////////////////////// -// Like `IntoDeserializer` but also implemented for `&[u8]`. This is used for -// the newtype fallthrough case of `field_identifier`. -// -// #[derive(Deserialize)] -// #[serde(field_identifier)] -// enum F { -// A, -// B, -// Other(String), // deserialized using IdentifierDeserializer -// } +/// Like `IntoDeserializer` but also implemented for `&[u8]`. This is used for +/// the newtype fallthrough case of `field_identifier`. +/// +/// ```ignore +/// #[derive(Deserialize)] +/// #[serde(field_identifier)] +/// enum F { +/// A, +/// B, +/// Other(String), // deserialized using IdentifierDeserializer +/// } +/// ``` pub trait IdentifierDeserializer<'de, E: Error> { + /// Deserializer, that refers to data owned by deserializer type Deserializer: Deserializer<'de, Error = E>; + /// Deserializer, that borrows data from the input + type BorrowedDeserializer: Deserializer<'de, Error = E>; fn from(self) -> Self::Deserializer; + fn borrowed(self) -> Self::BorrowedDeserializer; } impl<'de, E> IdentifierDeserializer<'de, E> for u32 @@ -2548,23 +2554,34 @@ where E: Error, { type Deserializer = >::Deserializer; + type BorrowedDeserializer = >::Deserializer; fn from(self) -> Self::Deserializer { self.into_deserializer() } + + fn borrowed(self) -> Self::BorrowedDeserializer { + self.into_deserializer() + } } forward_deserializer!(ref StrDeserializer<'a>(&'a str) => visit_str); +forward_deserializer!(borrowed BorrowedStrDeserializer(&'de str) => visit_borrowed_str); impl<'a, E> IdentifierDeserializer<'a, E> for &'a str where E: Error, { type Deserializer = StrDeserializer<'a, E>; + type BorrowedDeserializer = BorrowedStrDeserializer<'a, E>; fn from(self) -> Self::Deserializer { StrDeserializer::new(self) } + + fn borrowed(self) -> Self::BorrowedDeserializer { + BorrowedStrDeserializer::new(self) + } } impl<'a, E> IdentifierDeserializer<'a, E> for &'a [u8] @@ -2572,10 +2589,15 @@ where E: Error, { type Deserializer = BytesDeserializer<'a, E>; + type BorrowedDeserializer = BorrowedBytesDeserializer<'a, E>; fn from(self) -> Self::Deserializer { BytesDeserializer::new(self) } + + fn borrowed(self) -> Self::BorrowedDeserializer { + BorrowedBytesDeserializer::new(self) + } } /// A DeserializeSeed helper for implementing deserialize_in_place Visitors. diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 1f5733a6..704985de 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1886,17 +1886,26 @@ fn deserialize_generated_identifier( let (ignore_variant, fallthrough) = if !is_variant && cattrs.has_flatten() { let ignore_variant = quote!(__other(_serde::private::de::Content<'de>),); let fallthrough = quote!(_serde::export::Ok(__Field::__other(__value))); - (Some(ignore_variant), Some(fallthrough)) + ( + Some(ignore_variant), + Some((fallthrough.clone(), fallthrough)) + ) } else if let Some(other_idx) = other_idx { let ignore_variant = fields[other_idx].1.clone(); let fallthrough = quote!(_serde::export::Ok(__Field::#ignore_variant)); - (None, Some(fallthrough)) + ( + None, + Some((fallthrough.clone(), fallthrough)) + ) } else if is_variant || cattrs.deny_unknown_fields() { (None, None) } else { let ignore_variant = quote!(__ignore,); let fallthrough = quote!(_serde::export::Ok(__Field::__ignore)); - (Some(ignore_variant), Some(fallthrough)) + ( + Some(ignore_variant), + Some((fallthrough.clone(), fallthrough)) + ) }; let visitor_impl = Stmts(deserialize_identifier( @@ -1959,16 +1968,27 @@ fn deserialize_custom_identifier( if last.attrs.other() { let ordinary = &variants[..variants.len() - 1]; let fallthrough = quote!(_serde::export::Ok(#this::#last_ident)); - (ordinary, Some(fallthrough)) + ( + ordinary, + Some((fallthrough.clone(), fallthrough)) + ) } else if let Style::Newtype = last.style { let ordinary = &variants[..variants.len() - 1]; - let deserializer = quote!(_serde::private::de::IdentifierDeserializer::from(__value)); - let fallthrough = quote! { + + let fallthrough = |method| quote! { _serde::export::Result::map( - _serde::Deserialize::deserialize(#deserializer), + _serde::Deserialize::deserialize( + _serde::private::de::IdentifierDeserializer::#method(__value) + ), #this::#last_ident) }; - (ordinary, Some(fallthrough)) + ( + ordinary, + Some(( + fallthrough(quote!(from)), + fallthrough(quote!(borrowed)), + )) + ) } else { (variants, None) } @@ -2040,7 +2060,8 @@ fn deserialize_identifier( this: &TokenStream, fields: &[(String, Ident, Vec)], is_variant: bool, - fallthrough: Option, + // .0 for referenced data, .1 -- for borrowed + fallthrough: Option<(TokenStream, TokenStream)>, collect_other_fields: bool, ) -> Fragment { let mut flat_fields = Vec::new(); @@ -2048,14 +2069,11 @@ fn deserialize_identifier( flat_fields.extend(aliases.iter().map(|alias| (alias, ident))) } - let field_strs = flat_fields.iter().map(|(name, _)| name); - let field_borrowed_strs = flat_fields.iter().map(|(name, _)| name); - let field_bytes = flat_fields + let field_strs: &Vec<_> = &flat_fields.iter().map(|(name, _)| name).collect(); + let field_bytes: &Vec<_> = &flat_fields .iter() - .map(|(name, _)| Literal::byte_string(name.as_bytes())); - let field_borrowed_bytes = flat_fields - .iter() - .map(|(name, _)| Literal::byte_string(name.as_bytes())); + .map(|(name, _)| Literal::byte_string(name.as_bytes())) + .collect(); let constructors: &Vec<_> = &flat_fields .iter() @@ -2106,16 +2124,21 @@ fn deserialize_identifier( (None, None, None, None) }; - let fallthrough_arm = if let Some(fallthrough) = fallthrough { + let ( + fallthrough_arm, + fallthrough_borrowed_arm, + ) = if let Some(fallthrough) = fallthrough.clone() { fallthrough } else if is_variant { - quote! { + let fallthrough = quote! { _serde::export::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) - } + }; + (fallthrough.clone(), fallthrough) } else { - quote! { + let fallthrough = quote! { _serde::export::Err(_serde::de::Error::unknown_field(__value, FIELDS)) - } + }; + (fallthrough.clone(), fallthrough) }; let variant_indices = 0_u64..; @@ -2212,37 +2235,6 @@ fn deserialize_identifier( { _serde::export::Ok(__Field::__other(_serde::private::de::Content::Unit)) } - - fn visit_borrowed_str<__E>(self, __value: &'de str) -> _serde::export::Result - where - __E: _serde::de::Error, - { - match __value { - #( - #field_borrowed_strs => _serde::export::Ok(#constructors), - )* - _ => { - #value_as_borrowed_str_content - #fallthrough_arm - } - } - } - - fn visit_borrowed_bytes<__E>(self, __value: &'de [u8]) -> _serde::export::Result - where - __E: _serde::de::Error, - { - match __value { - #( - #field_borrowed_bytes => _serde::export::Ok(#constructors), - )* - _ => { - #bytes_to_str - #value_as_borrowed_bytes_content - #fallthrough_arm - } - } - } } } else { quote! { @@ -2285,6 +2277,21 @@ fn deserialize_identifier( } } + fn visit_borrowed_str<__E>(self, __value: &'de str) -> _serde::export::Result + where + __E: _serde::de::Error, + { + match __value { + #( + #field_strs => _serde::export::Ok(#constructors), + )* + _ => { + #value_as_borrowed_str_content + #fallthrough_borrowed_arm + } + } + } + fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::export::Result where __E: _serde::de::Error, @@ -2300,6 +2307,22 @@ fn deserialize_identifier( } } } + + fn visit_borrowed_bytes<__E>(self, __value: &'de [u8]) -> _serde::export::Result + where + __E: _serde::de::Error, + { + match __value { + #( + #field_bytes => _serde::export::Ok(#constructors), + )* + _ => { + #bytes_to_str + #value_as_borrowed_bytes_content + #fallthrough_borrowed_arm + } + } + } } } diff --git a/test_suite/tests/test_borrow.rs b/test_suite/tests/test_borrow.rs index e76ede69..15139cb9 100644 --- a/test_suite/tests/test_borrow.rs +++ b/test_suite/tests/test_borrow.rs @@ -90,6 +90,37 @@ fn test_struct() { ); } +#[test] +fn test_field_identifier() { + #[derive(Deserialize, Debug, PartialEq)] + #[serde(field_identifier)] + enum FieldStr<'a> { + #[serde(borrow)] + Str(&'a str), + } + + assert_de_tokens( + &FieldStr::Str("value"), + &[ + Token::BorrowedStr("value"), + ], + ); + + #[derive(Deserialize, Debug, PartialEq)] + #[serde(field_identifier)] + enum FieldBytes<'a> { + #[serde(borrow)] + Bytes(&'a [u8]), + } + + assert_de_tokens( + &FieldBytes::Bytes(b"value"), + &[ + Token::BorrowedBytes(b"value"), + ], + ); +} + #[test] fn test_cow() { #[derive(Deserialize)]