From 4ff93398fd275d1ee9f31ee7d92ae59bf9559f60 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 10 Jul 2023 15:19:00 +0200 Subject: [PATCH 1/3] Skip buildin subtrees for builtin derives --- crates/hir-expand/src/builtin_derive_macro.rs | 86 ++++++----- crates/hir-expand/src/db.rs | 94 ++++++++---- crates/mbe/src/lib.rs | 3 +- crates/mbe/src/syntax_bridge.rs | 136 ++++++++++++++++++ 4 files changed, 243 insertions(+), 76 deletions(-) diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs index 78218a83452..c8c3cebb88e 100644 --- a/crates/hir-expand/src/builtin_derive_macro.rs +++ b/crates/hir-expand/src/builtin_derive_macro.rs @@ -12,9 +12,7 @@ use crate::{ name::{AsName, Name}, tt::{self, TokenId}, }; -use syntax::ast::{ - self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, -}; +use syntax::ast::{self, AstNode, FieldList, HasAttrs, HasGenericParams, HasName, HasTypeBounds}; use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId}; @@ -30,12 +28,13 @@ macro_rules! register_builtin { &self, db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + token_map: &TokenMap, ) -> ExpandResult { let expander = match *self { $( BuiltinDeriveExpander::$trait => $expand, )* }; - expander(db, id, tt) + expander(db, id, tt, token_map) } fn find_by_name(name: &name::Name) -> Option { @@ -118,13 +117,13 @@ impl VariantShape { } } - fn from(value: Option, token_map: &TokenMap) -> Result { + fn from(tm: &TokenMap, value: Option) -> Result { let r = match value { None => VariantShape::Unit, Some(FieldList::RecordFieldList(it)) => VariantShape::Struct( it.fields() .map(|it| it.name()) - .map(|it| name_to_token(token_map, it)) + .map(|it| name_to_token(tm, it)) .collect::>()?, ), Some(FieldList::TupleFieldList(it)) => VariantShape::Tuple(it.fields().count()), @@ -190,25 +189,12 @@ struct BasicAdtInfo { associated_types: Vec, } -fn parse_adt(tt: &tt::Subtree) -> Result { - let (parsed, token_map) = mbe::token_tree_to_syntax_node(tt, mbe::TopEntryPoint::MacroItems); - let macro_items = ast::MacroItems::cast(parsed.syntax_node()).ok_or_else(|| { - debug!("derive node didn't parse"); - ExpandError::other("invalid item definition") - })?; - let item = macro_items.items().next().ok_or_else(|| { - debug!("no module item parsed"); - ExpandError::other("no item found") - })?; - let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| { - debug!("expected adt, found: {:?}", item); - ExpandError::other("expected struct, enum or union") - })?; +fn parse_adt(tm: &TokenMap, adt: &ast::Adt) -> Result { let (name, generic_param_list, shape) = match &adt { ast::Adt::Struct(it) => ( it.name(), it.generic_param_list(), - AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?), + AdtShape::Struct(VariantShape::from(tm, it.field_list())?), ), ast::Adt::Enum(it) => { let default_variant = it @@ -227,8 +213,8 @@ fn parse_adt(tt: &tt::Subtree) -> Result { .flat_map(|it| it.variants()) .map(|it| { Ok(( - name_to_token(&token_map, it.name())?, - VariantShape::from(it.field_list(), &token_map)?, + name_to_token(tm, it.name())?, + VariantShape::from(tm, it.field_list())?, )) }) .collect::>()?, @@ -298,7 +284,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result { }) .map(|it| mbe::syntax_node_to_token_tree(it.syntax()).0) .collect(); - let name_token = name_to_token(&token_map, name)?; + let name_token = name_to_token(&tm, name)?; Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types }) } @@ -345,11 +331,12 @@ fn name_to_token(token_map: &TokenMap, name: Option) -> Result tt::Subtree, ) -> ExpandResult { - let info = match parse_adt(tt) { + let info = match parse_adt(tm, tt) { Ok(info) => info, Err(e) => return ExpandResult::new(tt::Subtree::empty(), e), }; @@ -405,19 +392,21 @@ fn find_builtin_crate(db: &dyn ExpandDatabase, id: MacroCallId) -> tt::TokenTree fn copy_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::marker::Copy }, |_| quote! {}) + expand_simple_derive(tt, tm, quote! { #krate::marker::Copy }, |_| quote! {}) } fn clone_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::clone::Clone }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::clone::Clone }, |adt| { if matches!(adt.shape, AdtShape::Union) { let star = tt::Punct { char: '*', @@ -479,10 +468,11 @@ fn and_and() -> ::tt::Subtree { fn default_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::default::Default }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::default::Default }, |adt| { let body = match &adt.shape { AdtShape::Struct(fields) => { let name = &adt.name; @@ -518,10 +508,11 @@ fn default_expand( fn debug_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::fmt::Debug }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::fmt::Debug }, |adt| { let for_variant = |name: String, v: &VariantShape| match v { VariantShape::Struct(fields) => { let for_fields = fields.iter().map(|it| { @@ -598,10 +589,11 @@ fn debug_expand( fn hash_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::hash::Hash }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::hash::Hash }, |adt| { if matches!(adt.shape, AdtShape::Union) { // FIXME: Return expand error here return quote! {}; @@ -646,19 +638,21 @@ fn hash_expand( fn eq_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::Eq }, |_| quote! {}) + expand_simple_derive(tt, tm, quote! { #krate::cmp::Eq }, |_| quote! {}) } fn partial_eq_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::PartialEq }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::PartialEq }, |adt| { if matches!(adt.shape, AdtShape::Union) { // FIXME: Return expand error here return quote! {}; @@ -722,10 +716,11 @@ fn self_and_other_patterns( fn ord_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::Ord }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::Ord }, |adt| { fn compare( krate: &tt::TokenTree, left: tt::Subtree, @@ -786,10 +781,11 @@ fn ord_expand( fn partial_ord_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::PartialOrd }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::PartialOrd }, |adt| { fn compare( krate: &tt::TokenTree, left: tt::Subtree, diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index 3f08a098693..f528dfed31e 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -55,7 +55,9 @@ impl TokenExpander { TokenExpander::Builtin(it) => it.expand(db, id, tt).map_err(Into::into), TokenExpander::BuiltinEager(it) => it.expand(db, id, tt).map_err(Into::into), TokenExpander::BuiltinAttr(it) => it.expand(db, id, tt), - TokenExpander::BuiltinDerive(it) => it.expand(db, id, tt), + TokenExpander::BuiltinDerive(_) => { + unreachable!("builtin derives should be expanded manually") + } TokenExpander::ProcMacro(_) => { unreachable!("ExpandDatabase::expand_proc_macro should be used for proc macros") } @@ -232,6 +234,11 @@ pub fn expand_speculative( MacroDefKind::BuiltInAttr(BuiltinAttrExpander::Derive, _) => { pseudo_derive_attr_expansion(&tt, attr_arg.as_ref()?) } + MacroDefKind::BuiltInDerive(expander, ..) => { + // this cast is a bit sus, can we avoid losing the typedness here? + let adt = ast::Adt::cast(speculative_args.clone()).unwrap(); + expander.expand(db, actual_macro_call, &adt, &spec_args_tmap) + } _ => macro_def.expand(db, actual_macro_call, &tt), }; @@ -333,6 +340,9 @@ fn macro_arg( Some(Arc::new((tt, tmap, fixups.undo_info))) } +/// Certain macro calls expect some nodes in the input to be preprocessed away, namely: +/// - derives expect all `#[derive(..)]` invocations up to the currently invoked one to be stripped +/// - attributes expect the invoking attribute to be stripped fn censor_for_macro_input(loc: &MacroCallLoc, node: &SyntaxNode) -> FxHashSet { // FIXME: handle `cfg_attr` (|| { @@ -451,37 +461,59 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult return db.expand_proc_macro(id), + MacroDefKind::BuiltInDerive(expander, ..) => { + let arg = db.macro_arg_text(id).unwrap(); - let expander = match db.macro_def(loc.def) { - Ok(it) => it, - // FIXME: We should make sure to enforce a variant that invalid macro - // definitions do not get expanders that could reach this call path! - Err(err) => { - return ExpandResult { - value: Arc::new(tt::Subtree { - delimiter: tt::Delimiter::UNSPECIFIED, - token_trees: vec![], - }), - err: Some(ExpandError::other(format!("invalid macro definition: {err}"))), - } + let node = SyntaxNode::new_root(arg); + let censor = censor_for_macro_input(&loc, &node); + let mut fixups = fixup::fixup_syntax(&node); + fixups.replace.extend(censor.into_iter().map(|node| (node.into(), Vec::new()))); + let (tmap, _) = mbe::syntax_node_to_token_map_with_modifications( + &node, + fixups.token_map, + fixups.next_id, + fixups.replace, + fixups.append, + ); + + // this cast is a bit sus, can we avoid losing the typedness here? + let adt = ast::Adt::cast(node).unwrap(); + (expander.expand(db, id, &adt, &tmap), Some((tmap, fixups.undo_info))) + } + _ => { + let expander = match db.macro_def(loc.def) { + Ok(it) => it, + // FIXME: We should make sure to enforce a variant that invalid macro + // definitions do not get expanders that could reach this call path! + Err(err) => { + return ExpandResult { + value: Arc::new(tt::Subtree { + delimiter: tt::Delimiter::UNSPECIFIED, + token_trees: vec![], + }), + err: Some(ExpandError::other(format!("invalid macro definition: {err}"))), + } + } + }; + let Some(macro_arg) = db.macro_arg(id) else { + return ExpandResult { + value: Arc::new(tt::Subtree { + delimiter: tt::Delimiter::UNSPECIFIED, + token_trees: Vec::new(), + }), + // FIXME: We should make sure to enforce an invariant that invalid macro + // calls do not reach this call path! + err: Some(ExpandError::other("invalid token tree")), + }; + }; + let (arg, arg_tm, undo_info) = &*macro_arg; + let mut res = expander.expand(db, id, arg); + fixup::reverse_fixups(&mut res.value, arg_tm, undo_info); + (res, None) } }; - let Some(macro_arg) = db.macro_arg(id) else { - return ExpandResult { - value: Arc::new(tt::Subtree { - delimiter: tt::Delimiter::UNSPECIFIED, - token_trees: Vec::new(), - }), - // FIXME: We should make sure to enforce an invariant that invalid macro - // calls do not reach this call path! - err: Some(ExpandError::other("invalid token tree")), - }; - }; - let (arg_tt, arg_tm, undo_info) = &*macro_arg; - let ExpandResult { value: mut tt, mut err } = expander.expand(db, id, arg_tt); if let Some(EagerCallInfo { error, .. }) = loc.eager.as_deref() { // FIXME: We should report both errors! @@ -493,7 +525,9 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult TokenMap { + syntax_node_to_token_map_with_modifications( + node, + Default::default(), + 0, + Default::default(), + Default::default(), + ) + .0 +} + +/// Convert the syntax node to a `TokenTree` (what macro will consume) +/// with the censored range excluded. +pub fn syntax_node_to_token_map_with_modifications( + node: &SyntaxNode, + existing_token_map: TokenMap, + next_id: u32, + replace: FxHashMap>, + append: FxHashMap>, +) -> (TokenMap, u32) { + let global_offset = node.text_range().start(); + let mut c = Converter::new(node, global_offset, existing_token_map, next_id, replace, append); + collect_tokens(&mut c); + c.id_alloc.map.shrink_to_fit(); + always!(c.replace.is_empty(), "replace: {:?}", c.replace); + always!(c.append.is_empty(), "append: {:?}", c.append); + (c.id_alloc.map, c.id_alloc.next_id) +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct SyntheticTokenId(pub u32); @@ -327,6 +358,111 @@ fn convert_tokens(conv: &mut C) -> tt::Subtree { } } +fn collect_tokens(conv: &mut C) { + struct StackEntry { + idx: usize, + open_range: TextRange, + delimiter: tt::DelimiterKind, + } + + let entry = StackEntry { + delimiter: tt::DelimiterKind::Invisible, + // never used (delimiter is `None`) + idx: !0, + open_range: TextRange::empty(TextSize::of('.')), + }; + let mut stack = NonEmptyVec::new(entry); + + loop { + let StackEntry { delimiter, .. } = stack.last_mut(); + let (token, range) = match conv.bump() { + Some(it) => it, + None => break, + }; + let synth_id = token.synthetic_id(conv); + + let kind = token.kind(conv); + if kind == COMMENT { + // Since `convert_doc_comment` can fail, we need to peek the next id, so that we can + // figure out which token id to use for the doc comment, if it is converted successfully. + let next_id = conv.id_alloc().peek_next_id(); + if let Some(_tokens) = conv.convert_doc_comment(&token, next_id) { + let id = conv.id_alloc().alloc(range, synth_id); + debug_assert_eq!(id, next_id); + } + continue; + } + if kind.is_punct() && kind != UNDERSCORE { + if synth_id.is_none() { + assert_eq!(range.len(), TextSize::of('.')); + } + + let expected = match delimiter { + tt::DelimiterKind::Parenthesis => Some(T![')']), + tt::DelimiterKind::Brace => Some(T!['}']), + tt::DelimiterKind::Bracket => Some(T![']']), + tt::DelimiterKind::Invisible => None, + }; + + if let Some(expected) = expected { + if kind == expected { + if let Some(entry) = stack.pop() { + conv.id_alloc().close_delim(entry.idx, Some(range)); + } + continue; + } + } + + let delim = match kind { + T!['('] => Some(tt::DelimiterKind::Parenthesis), + T!['{'] => Some(tt::DelimiterKind::Brace), + T!['['] => Some(tt::DelimiterKind::Bracket), + _ => None, + }; + + if let Some(kind) = delim { + let (_id, idx) = conv.id_alloc().open_delim(range, synth_id); + + stack.push(StackEntry { idx, open_range: range, delimiter: kind }); + continue; + } + + conv.id_alloc().alloc(range, synth_id); + } else { + macro_rules! make_leaf { + ($i:ident) => {{ + conv.id_alloc().alloc(range, synth_id); + }}; + } + match kind { + T![true] | T![false] => make_leaf!(Ident), + IDENT => make_leaf!(Ident), + UNDERSCORE => make_leaf!(Ident), + k if k.is_keyword() => make_leaf!(Ident), + k if k.is_literal() => make_leaf!(Literal), + LIFETIME_IDENT => { + let char_unit = TextSize::of('\''); + let r = TextRange::at(range.start(), char_unit); + conv.id_alloc().alloc(r, synth_id); + + let r = TextRange::at(range.start() + char_unit, range.len() - char_unit); + conv.id_alloc().alloc(r, synth_id); + continue; + } + _ => continue, + }; + }; + + // If we get here, we've consumed all input tokens. + // We might have more than one subtree in the stack, if the delimiters are improperly balanced. + // Merge them so we're left with one. + while let Some(entry) = stack.pop() { + conv.id_alloc().close_delim(entry.idx, None); + conv.id_alloc().alloc(entry.open_range, None); + } + } +} + fn is_single_token_op(kind: SyntaxKind) -> bool { matches!( kind, From d5f64f875a1eac5eb723a077af4a08778dc978fe Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 10 Jul 2023 16:23:29 +0200 Subject: [PATCH 2/3] Infallibe ExpandDatabase::macro_def --- .../hir-def/src/macro_expansion_tests/mod.rs | 49 +++-- crates/hir-expand/src/db.rs | 195 ++++++++++-------- crates/hir-expand/src/hygiene.rs | 10 +- crates/hir-expand/src/lib.rs | 12 +- crates/hir/src/db.rs | 6 +- crates/hir/src/lib.rs | 22 +- crates/ide-db/src/apply_change.rs | 2 +- crates/ide-db/src/lib.rs | 2 +- crates/mbe/src/benchmark.rs | 7 +- crates/mbe/src/lib.rs | 74 +++++-- 10 files changed, 215 insertions(+), 164 deletions(-) diff --git a/crates/hir-def/src/macro_expansion_tests/mod.rs b/crates/hir-def/src/macro_expansion_tests/mod.rs index 836f0afb1d8..1c15c1b7f06 100644 --- a/crates/hir-def/src/macro_expansion_tests/mod.rs +++ b/crates/hir-def/src/macro_expansion_tests/mod.rs @@ -20,8 +20,8 @@ use ::mbe::TokenMap; use base_db::{fixture::WithFixture, ProcMacro, SourceDatabase}; use expect_test::Expect; use hir_expand::{ - db::{ExpandDatabase, TokenExpander}, - AstId, InFile, MacroDefId, MacroDefKind, MacroFile, + db::{DeclarativeMacroExpander, ExpandDatabase}, + AstId, InFile, MacroFile, }; use stdx::format_to; use syntax::{ @@ -100,32 +100,29 @@ pub fn identity_when_valid(_attr: TokenStream, item: TokenStream) -> TokenStream let call_offset = macro_.syntax().text_range().start().into(); let file_ast_id = db.ast_id_map(source.file_id).ast_id(¯o_); let ast_id = AstId::new(source.file_id, file_ast_id.upcast()); - let kind = MacroDefKind::Declarative(ast_id); - let macro_def = db - .macro_def(MacroDefId { krate, kind, local_inner: false, allow_internal_unsafe: false }) - .unwrap(); - if let TokenExpander::DeclarativeMacro { mac, def_site_token_map } = &*macro_def { - let tt = match ¯o_ { - ast::Macro::MacroRules(mac) => mac.token_tree().unwrap(), - ast::Macro::MacroDef(_) => unimplemented!(""), - }; + let DeclarativeMacroExpander { mac, def_site_token_map } = + &*db.decl_macro_expander(krate, ast_id); + assert_eq!(mac.err(), None); + let tt = match ¯o_ { + ast::Macro::MacroRules(mac) => mac.token_tree().unwrap(), + ast::Macro::MacroDef(_) => unimplemented!(""), + }; - let tt_start = tt.syntax().text_range().start(); - tt.syntax().descendants_with_tokens().filter_map(SyntaxElement::into_token).for_each( - |token| { - let range = token.text_range().checked_sub(tt_start).unwrap(); - if let Some(id) = def_site_token_map.token_by_range(range) { - let offset = (range.end() + tt_start).into(); - text_edits.push((offset..offset, format!("#{}", id.0))); - } - }, - ); - text_edits.push(( - call_offset..call_offset, - format!("// call ids will be shifted by {:?}\n", mac.shift()), - )); - } + let tt_start = tt.syntax().text_range().start(); + tt.syntax().descendants_with_tokens().filter_map(SyntaxElement::into_token).for_each( + |token| { + let range = token.text_range().checked_sub(tt_start).unwrap(); + if let Some(id) = def_site_token_map.token_by_range(range) { + let offset = (range.end() + tt_start).into(); + text_edits.push((offset..offset, format!("#{}", id.0))); + } + }, + ); + text_edits.push(( + call_offset..call_offset, + format!("// call ids will be shifted by {:?}\n", mac.shift()), + )); } for macro_call in source_file.syntax().descendants().filter_map(ast::MacroCall::cast) { diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index f528dfed31e..b138b91a8a5 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -1,6 +1,6 @@ //! Defines database & queries for macro expansion. -use base_db::{salsa, Edition, SourceDatabase}; +use base_db::{salsa, CrateId, Edition, SourceDatabase}; use either::Either; use limit::Limit; use mbe::syntax_node_to_token_tree; @@ -13,7 +13,7 @@ use triomphe::Arc; use crate::{ ast_id_map::AstIdMap, builtin_attr_macro::pseudo_derive_attr_expansion, - builtin_fn_macro::EagerExpander, fixup, hygiene::HygieneFrame, tt, BuiltinAttrExpander, + builtin_fn_macro::EagerExpander, fixup, hygiene::HygieneFrame, tt, AstId, BuiltinAttrExpander, BuiltinDeriveExpander, BuiltinFnLikeExpander, EagerCallInfo, ExpandError, ExpandResult, ExpandTo, HirFileId, HirFileIdRepr, MacroCallId, MacroCallKind, MacroCallLoc, MacroDefId, MacroDefKind, MacroFile, ProcMacroExpander, @@ -27,61 +27,59 @@ use crate::{ /// Actual max for `analysis-stats .` at some point: 30672. static TOKEN_LIMIT: Limit = Limit::new(1_048_576); +#[derive(Debug, Clone, Eq, PartialEq)] +/// Old-style `macro_rules` or the new macros 2.0 +pub struct DeclarativeMacroExpander { + pub mac: mbe::DeclarativeMacro, + pub def_site_token_map: mbe::TokenMap, +} + +impl DeclarativeMacroExpander { + pub fn expand(&self, tt: &tt::Subtree) -> ExpandResult { + match self.mac.err() { + Some(e) => ExpandResult::new( + tt::Subtree::empty(), + ExpandError::other(format!("invalid macro definition: {e}")), + ), + None => self.mac.expand(tt).map_err(Into::into), + } + } +} + #[derive(Debug, Clone, Eq, PartialEq)] pub enum TokenExpander { - /// Old-style `macro_rules` or the new macros 2.0 - DeclarativeMacro { mac: mbe::DeclarativeMacro, def_site_token_map: mbe::TokenMap }, + DeclarativeMacro(Arc), /// Stuff like `line!` and `file!`. - Builtin(BuiltinFnLikeExpander), + BuiltIn(BuiltinFnLikeExpander), /// Built-in eagerly expanded fn-like macros (`include!`, `concat!`, etc.) - BuiltinEager(EagerExpander), + BuiltInEager(EagerExpander), /// `global_allocator` and such. - BuiltinAttr(BuiltinAttrExpander), + BuiltInAttr(BuiltinAttrExpander), /// `derive(Copy)` and such. - BuiltinDerive(BuiltinDeriveExpander), + BuiltInDerive(BuiltinDeriveExpander), /// The thing we love the most here in rust-analyzer -- procedural macros. ProcMacro(ProcMacroExpander), } impl TokenExpander { - fn expand( - &self, - db: &dyn ExpandDatabase, - id: MacroCallId, - tt: &tt::Subtree, - ) -> ExpandResult { - match self { - TokenExpander::DeclarativeMacro { mac, .. } => mac.expand(tt).map_err(Into::into), - TokenExpander::Builtin(it) => it.expand(db, id, tt).map_err(Into::into), - TokenExpander::BuiltinEager(it) => it.expand(db, id, tt).map_err(Into::into), - TokenExpander::BuiltinAttr(it) => it.expand(db, id, tt), - TokenExpander::BuiltinDerive(_) => { - unreachable!("builtin derives should be expanded manually") - } - TokenExpander::ProcMacro(_) => { - unreachable!("ExpandDatabase::expand_proc_macro should be used for proc macros") - } - } - } - pub(crate) fn map_id_down(&self, id: tt::TokenId) -> tt::TokenId { match self { - TokenExpander::DeclarativeMacro { mac, .. } => mac.map_id_down(id), - TokenExpander::Builtin(..) - | TokenExpander::BuiltinEager(..) - | TokenExpander::BuiltinAttr(..) - | TokenExpander::BuiltinDerive(..) + TokenExpander::DeclarativeMacro(expander) => expander.mac.map_id_down(id), + TokenExpander::BuiltIn(..) + | TokenExpander::BuiltInEager(..) + | TokenExpander::BuiltInAttr(..) + | TokenExpander::BuiltInDerive(..) | TokenExpander::ProcMacro(..) => id, } } pub(crate) fn map_id_up(&self, id: tt::TokenId) -> (tt::TokenId, mbe::Origin) { match self { - TokenExpander::DeclarativeMacro { mac, .. } => mac.map_id_up(id), - TokenExpander::Builtin(..) - | TokenExpander::BuiltinEager(..) - | TokenExpander::BuiltinAttr(..) - | TokenExpander::BuiltinDerive(..) + TokenExpander::DeclarativeMacro(expander) => expander.mac.map_id_up(id), + TokenExpander::BuiltIn(..) + | TokenExpander::BuiltInEager(..) + | TokenExpander::BuiltInAttr(..) + | TokenExpander::BuiltInDerive(..) | TokenExpander::ProcMacro(..) => (id, mbe::Origin::Call), } } @@ -124,7 +122,14 @@ pub trait ExpandDatabase: SourceDatabase { fn macro_arg_text(&self, id: MacroCallId) -> Option; /// Gets the expander for this macro. This compiles declarative macros, and /// just fetches procedural ones. - fn macro_def(&self, id: MacroDefId) -> Result, mbe::ParseError>; + // FIXME: Rename this + #[salsa::transparent] + fn macro_def(&self, id: MacroDefId) -> TokenExpander; + fn decl_macro_expander( + &self, + def_crate: CrateId, + id: AstId, + ) -> Arc; /// Expand macro call to a token tree. // This query is LRU cached @@ -162,7 +167,7 @@ pub fn expand_speculative( token_to_map: SyntaxToken, ) -> Option<(SyntaxNode, SyntaxToken)> { let loc = db.lookup_intern_macro_call(actual_macro_call); - let macro_def = db.macro_def(loc.def).ok()?; + let macro_def = db.macro_def(loc.def); let token_range = token_to_map.text_range(); // Build the subtree and token mapping for the speculative args @@ -239,7 +244,12 @@ pub fn expand_speculative( let adt = ast::Adt::cast(speculative_args.clone()).unwrap(); expander.expand(db, actual_macro_call, &adt, &spec_args_tmap) } - _ => macro_def.expand(db, actual_macro_call, &tt), + MacroDefKind::Declarative(it) => db.decl_macro_expander(loc.krate, it).expand(&tt), + MacroDefKind::BuiltIn(it, _) => it.expand(db, actual_macro_call, &tt).map_err(Into::into), + MacroDefKind::BuiltInEager(it, _) => { + it.expand(db, actual_macro_call, &tt).map_err(Into::into) + } + MacroDefKind::BuiltInAttr(it, _) => it.expand(db, actual_macro_call, &tt), }; let expand_to = macro_expand_to(db, actual_macro_call); @@ -412,44 +422,55 @@ fn macro_arg_text(db: &dyn ExpandDatabase, id: MacroCallId) -> Option } } -fn macro_def( +fn decl_macro_expander( db: &dyn ExpandDatabase, - id: MacroDefId, -) -> Result, mbe::ParseError> { + def_crate: CrateId, + id: AstId, +) -> Arc { + let is_2021 = db.crate_graph()[def_crate].edition >= Edition::Edition2021; + let (mac, def_site_token_map) = match id.to_node(db) { + ast::Macro::MacroRules(macro_rules) => match macro_rules.token_tree() { + Some(arg) => { + let (tt, def_site_token_map) = mbe::syntax_node_to_token_tree(arg.syntax()); + let mac = mbe::DeclarativeMacro::parse_macro_rules(&tt, is_2021); + (mac, def_site_token_map) + } + None => ( + mbe::DeclarativeMacro::from_err( + mbe::ParseError::Expected("expected a token tree".into()), + is_2021, + ), + Default::default(), + ), + }, + ast::Macro::MacroDef(macro_def) => match macro_def.body() { + Some(arg) => { + let (tt, def_site_token_map) = mbe::syntax_node_to_token_tree(arg.syntax()); + let mac = mbe::DeclarativeMacro::parse_macro2(&tt, is_2021); + (mac, def_site_token_map) + } + None => ( + mbe::DeclarativeMacro::from_err( + mbe::ParseError::Expected("expected a token tree".into()), + is_2021, + ), + Default::default(), + ), + }, + }; + Arc::new(DeclarativeMacroExpander { mac, def_site_token_map }) +} + +fn macro_def(db: &dyn ExpandDatabase, id: MacroDefId) -> TokenExpander { match id.kind { MacroDefKind::Declarative(ast_id) => { - let is_2021 = db.crate_graph()[id.krate].edition >= Edition::Edition2021; - let (mac, def_site_token_map) = match ast_id.to_node(db) { - ast::Macro::MacroRules(macro_rules) => { - let arg = macro_rules - .token_tree() - .ok_or_else(|| mbe::ParseError::Expected("expected a token tree".into()))?; - let (tt, def_site_token_map) = mbe::syntax_node_to_token_tree(arg.syntax()); - let mac = mbe::DeclarativeMacro::parse_macro_rules(&tt, is_2021)?; - (mac, def_site_token_map) - } - ast::Macro::MacroDef(macro_def) => { - let arg = macro_def - .body() - .ok_or_else(|| mbe::ParseError::Expected("expected a token tree".into()))?; - let (tt, def_site_token_map) = mbe::syntax_node_to_token_tree(arg.syntax()); - let mac = mbe::DeclarativeMacro::parse_macro2(&tt, is_2021)?; - (mac, def_site_token_map) - } - }; - Ok(Arc::new(TokenExpander::DeclarativeMacro { mac, def_site_token_map })) + TokenExpander::DeclarativeMacro(db.decl_macro_expander(id.krate, ast_id)) } - MacroDefKind::BuiltIn(expander, _) => Ok(Arc::new(TokenExpander::Builtin(expander))), - MacroDefKind::BuiltInAttr(expander, _) => { - Ok(Arc::new(TokenExpander::BuiltinAttr(expander))) - } - MacroDefKind::BuiltInDerive(expander, _) => { - Ok(Arc::new(TokenExpander::BuiltinDerive(expander))) - } - MacroDefKind::BuiltInEager(expander, ..) => { - Ok(Arc::new(TokenExpander::BuiltinEager(expander))) - } - MacroDefKind::ProcMacro(expander, ..) => Ok(Arc::new(TokenExpander::ProcMacro(expander))), + MacroDefKind::BuiltIn(expander, _) => TokenExpander::BuiltIn(expander), + MacroDefKind::BuiltInAttr(expander, _) => TokenExpander::BuiltInAttr(expander), + MacroDefKind::BuiltInDerive(expander, _) => TokenExpander::BuiltInDerive(expander), + MacroDefKind::BuiltInEager(expander, ..) => TokenExpander::BuiltInEager(expander), + MacroDefKind::ProcMacro(expander, ..) => TokenExpander::ProcMacro(expander), } } @@ -483,20 +504,6 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult { - let expander = match db.macro_def(loc.def) { - Ok(it) => it, - // FIXME: We should make sure to enforce a variant that invalid macro - // definitions do not get expanders that could reach this call path! - Err(err) => { - return ExpandResult { - value: Arc::new(tt::Subtree { - delimiter: tt::Delimiter::UNSPECIFIED, - token_trees: vec![], - }), - err: Some(ExpandError::other(format!("invalid macro definition: {err}"))), - } - } - }; let Some(macro_arg) = db.macro_arg(id) else { return ExpandResult { value: Arc::new(tt::Subtree { @@ -509,7 +516,15 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult { + db.decl_macro_expander(loc.def.krate, id).expand(&arg) + } + MacroDefKind::BuiltIn(it, _) => it.expand(db, id, &arg).map_err(Into::into), + MacroDefKind::BuiltInEager(it, _) => it.expand(db, id, &arg).map_err(Into::into), + MacroDefKind::BuiltInAttr(it, _) => it.expand(db, id, &arg), + _ => unreachable!(), + }; fixup::reverse_fixups(&mut res.value, arg_tm, undo_info); (res, None) } diff --git a/crates/hir-expand/src/hygiene.rs b/crates/hir-expand/src/hygiene.rs index 10f8fe9cec4..b2921bb173b 100644 --- a/crates/hir-expand/src/hygiene.rs +++ b/crates/hir-expand/src/hygiene.rs @@ -126,7 +126,7 @@ struct HygieneInfo { /// The start offset of the `macro_rules!` arguments or attribute input. attr_input_or_mac_def_start: Option>, - macro_def: Arc, + macro_def: TokenExpander, macro_arg: Arc<(crate::tt::Subtree, mbe::TokenMap, fixup::SyntaxFixupUndoInfo)>, macro_arg_shift: mbe::Shift, exp_map: Arc, @@ -159,9 +159,9 @@ impl HygieneInfo { &self.macro_arg.1, InFile::new(loc.kind.file_id(), loc.kind.arg(db)?.text_range().start()), ), - mbe::Origin::Def => match (&*self.macro_def, &self.attr_input_or_mac_def_start) { - (TokenExpander::DeclarativeMacro { def_site_token_map, .. }, Some(tt)) => { - (def_site_token_map, *tt) + mbe::Origin::Def => match (&self.macro_def, &self.attr_input_or_mac_def_start) { + (TokenExpander::DeclarativeMacro(expander), Some(tt)) => { + (&expander.def_site_token_map, *tt) } _ => panic!("`Origin::Def` used with non-`macro_rules!` macro"), }, @@ -198,7 +198,7 @@ fn make_hygiene_info( _ => None, }); - let macro_def = db.macro_def(loc.def).ok()?; + let macro_def = db.macro_def(loc.def); let (_, exp_map) = db.parse_macro_expansion(macro_file).value; let macro_arg = db.macro_arg(macro_file.macro_call_id).unwrap_or_else(|| { Arc::new(( diff --git a/crates/hir-expand/src/lib.rs b/crates/hir-expand/src/lib.rs index 3a534f56e4f..a92c17f4ed0 100644 --- a/crates/hir-expand/src/lib.rs +++ b/crates/hir-expand/src/lib.rs @@ -274,7 +274,7 @@ impl HirFileId { let arg_tt = loc.kind.arg(db)?; - let macro_def = db.macro_def(loc.def).ok()?; + let macro_def = db.macro_def(loc.def); let (parse, exp_map) = db.parse_macro_expansion(macro_file).value; let macro_arg = db.macro_arg(macro_file.macro_call_id).unwrap_or_else(|| { Arc::new(( @@ -287,7 +287,7 @@ impl HirFileId { let def = loc.def.ast_id().left().and_then(|id| { let def_tt = match id.to_node(db) { ast::Macro::MacroRules(mac) => mac.token_tree()?, - ast::Macro::MacroDef(_) if matches!(*macro_def, TokenExpander::BuiltinAttr(_)) => { + ast::Macro::MacroDef(_) if matches!(macro_def, TokenExpander::BuiltInAttr(_)) => { return None } ast::Macro::MacroDef(mac) => mac.body()?, @@ -633,7 +633,7 @@ pub struct ExpansionInfo { /// The `macro_rules!` or attribute input. attr_input_or_mac_def: Option>, - macro_def: Arc, + macro_def: TokenExpander, macro_arg: Arc<(tt::Subtree, mbe::TokenMap, fixup::SyntaxFixupUndoInfo)>, /// A shift built from `macro_arg`'s subtree, relevant for attributes as the item is the macro arg /// and as such we need to shift tokens if they are part of an attributes input instead of their item. @@ -780,9 +780,9 @@ impl ExpansionInfo { } _ => match origin { mbe::Origin::Call => (&self.macro_arg.1, self.arg.clone()), - mbe::Origin::Def => match (&*self.macro_def, &self.attr_input_or_mac_def) { - (TokenExpander::DeclarativeMacro { def_site_token_map, .. }, Some(tt)) => { - (def_site_token_map, tt.syntax().cloned()) + mbe::Origin::Def => match (&self.macro_def, &self.attr_input_or_mac_def) { + (TokenExpander::DeclarativeMacro(expander), Some(tt)) => { + (&expander.def_site_token_map, tt.syntax().cloned()) } _ => panic!("`Origin::Def` used with non-`macro_rules!` macro"), }, diff --git a/crates/hir/src/db.rs b/crates/hir/src/db.rs index e0cde689fed..fa2deb72152 100644 --- a/crates/hir/src/db.rs +++ b/crates/hir/src/db.rs @@ -5,9 +5,9 @@ //! But we need this for at least LRU caching at the query level. pub use hir_def::db::*; pub use hir_expand::db::{ - AstIdMapQuery, ExpandDatabase, ExpandDatabaseStorage, ExpandProcMacroQuery, HygieneFrameQuery, - InternMacroCallQuery, MacroArgTextQuery, MacroDefQuery, MacroExpandQuery, - ParseMacroExpansionErrorQuery, ParseMacroExpansionQuery, + AstIdMapQuery, DeclMacroExpanderQuery, ExpandDatabase, ExpandDatabaseStorage, + ExpandProcMacroQuery, HygieneFrameQuery, InternMacroCallQuery, MacroArgTextQuery, + MacroExpandQuery, ParseMacroExpansionErrorQuery, ParseMacroExpansionQuery, }; pub use hir_ty::db::*; diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 44f901fb3a3..d09963d4d31 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -698,16 +698,18 @@ impl Module { fn emit_macro_def_diagnostics(db: &dyn HirDatabase, acc: &mut Vec, m: Macro) { let id = macro_id_to_def_id(db.upcast(), m.id); - if let Err(e) = db.macro_def(id) { - let Some(ast) = id.ast_id().left() else { - never!("MacroDefError for proc-macro: {:?}", e); - return; - }; - emit_def_diagnostic_( - db, - acc, - &DefDiagnosticKind::MacroDefError { ast, message: e.to_string() }, - ); + if let hir_expand::db::TokenExpander::DeclarativeMacro(expander) = db.macro_def(id) { + if let Some(e) = expander.mac.err() { + let Some(ast) = id.ast_id().left() else { + never!("declarative expander for non decl-macro: {:?}", e); + return; + }; + emit_def_diagnostic_( + db, + acc, + &DefDiagnosticKind::MacroDefError { ast, message: e.to_string() }, + ); + } } } diff --git a/crates/ide-db/src/apply_change.rs b/crates/ide-db/src/apply_change.rs index 0dd544d0ae2..7ba0bd73ec4 100644 --- a/crates/ide-db/src/apply_change.rs +++ b/crates/ide-db/src/apply_change.rs @@ -100,7 +100,7 @@ impl RootDatabase { hir::db::ParseMacroExpansionQuery hir::db::InternMacroCallQuery hir::db::MacroArgTextQuery - hir::db::MacroDefQuery + hir::db::DeclMacroExpanderQuery hir::db::MacroExpandQuery hir::db::ExpandProcMacroQuery hir::db::HygieneFrameQuery diff --git a/crates/ide-db/src/lib.rs b/crates/ide-db/src/lib.rs index ff1a20f03f4..feced7e4bbb 100644 --- a/crates/ide-db/src/lib.rs +++ b/crates/ide-db/src/lib.rs @@ -201,7 +201,7 @@ impl RootDatabase { // hir_db::ParseMacroExpansionQuery // hir_db::InternMacroCallQuery hir_db::MacroArgTextQuery - hir_db::MacroDefQuery + hir_db::DeclMacroExpanderQuery // hir_db::MacroExpandQuery hir_db::ExpandProcMacroQuery hir_db::HygieneFrameQuery diff --git a/crates/mbe/src/benchmark.rs b/crates/mbe/src/benchmark.rs index d28dd17def3..793326a8417 100644 --- a/crates/mbe/src/benchmark.rs +++ b/crates/mbe/src/benchmark.rs @@ -20,10 +20,7 @@ fn benchmark_parse_macro_rules() { let rules = macro_rules_fixtures_tt(); let hash: usize = { let _pt = bench("mbe parse macro rules"); - rules - .values() - .map(|it| DeclarativeMacro::parse_macro_rules(it, true).unwrap().rules.len()) - .sum() + rules.values().map(|it| DeclarativeMacro::parse_macro_rules(it, true).rules.len()).sum() }; assert_eq!(hash, 1144); } @@ -53,7 +50,7 @@ fn benchmark_expand_macro_rules() { fn macro_rules_fixtures() -> FxHashMap { macro_rules_fixtures_tt() .into_iter() - .map(|(id, tt)| (id, DeclarativeMacro::parse_macro_rules(&tt, true).unwrap())) + .map(|(id, tt)| (id, DeclarativeMacro::parse_macro_rules(&tt, true))) .collect() } diff --git a/crates/mbe/src/lib.rs b/crates/mbe/src/lib.rs index 09af93f1257..b528af8c6a7 100644 --- a/crates/mbe/src/lib.rs +++ b/crates/mbe/src/lib.rs @@ -132,6 +132,7 @@ pub struct DeclarativeMacro { // This is used for correctly determining the behavior of the pat fragment // FIXME: This should be tracked by hygiene of the fragment identifier! is_2021: bool, + err: Option>, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -214,65 +215,100 @@ pub enum Origin { } impl DeclarativeMacro { + pub fn from_err(err: ParseError, is_2021: bool) -> DeclarativeMacro { + DeclarativeMacro { + rules: Box::default(), + shift: Shift(0), + is_2021, + err: Some(Box::new(err)), + } + } + /// The old, `macro_rules! m {}` flavor. - pub fn parse_macro_rules( - tt: &tt::Subtree, - is_2021: bool, - ) -> Result { + pub fn parse_macro_rules(tt: &tt::Subtree, is_2021: bool) -> DeclarativeMacro { // Note: this parsing can be implemented using mbe machinery itself, by // matching against `$($lhs:tt => $rhs:tt);*` pattern, but implementing // manually seems easier. let mut src = TtIter::new(tt); let mut rules = Vec::new(); + let mut err = None; + while src.len() > 0 { - let rule = Rule::parse(&mut src, true)?; + let rule = match Rule::parse(&mut src, true) { + Ok(it) => it, + Err(e) => { + err = Some(Box::new(e)); + break; + } + }; rules.push(rule); if let Err(()) = src.expect_char(';') { if src.len() > 0 { - return Err(ParseError::expected("expected `;`")); + err = Some(Box::new(ParseError::expected("expected `;`"))); } break; } } for Rule { lhs, .. } in &rules { - validate(lhs)?; + if let Err(e) = validate(lhs) { + err = Some(Box::new(e)); + break; + } } - Ok(DeclarativeMacro { rules: rules.into_boxed_slice(), shift: Shift::new(tt), is_2021 }) + DeclarativeMacro { rules: rules.into_boxed_slice(), shift: Shift::new(tt), is_2021, err } } /// The new, unstable `macro m {}` flavor. - pub fn parse_macro2(tt: &tt::Subtree, is_2021: bool) -> Result { + pub fn parse_macro2(tt: &tt::Subtree, is_2021: bool) -> DeclarativeMacro { let mut src = TtIter::new(tt); let mut rules = Vec::new(); + let mut err = None; if tt::DelimiterKind::Brace == tt.delimiter.kind { cov_mark::hit!(parse_macro_def_rules); while src.len() > 0 { - let rule = Rule::parse(&mut src, true)?; + let rule = match Rule::parse(&mut src, true) { + Ok(it) => it, + Err(e) => { + err = Some(Box::new(e)); + break; + } + }; rules.push(rule); if let Err(()) = src.expect_any_char(&[';', ',']) { if src.len() > 0 { - return Err(ParseError::expected("expected `;` or `,` to delimit rules")); + err = Some(Box::new(ParseError::expected( + "expected `;` or `,` to delimit rules", + ))); } break; } } } else { cov_mark::hit!(parse_macro_def_simple); - let rule = Rule::parse(&mut src, false)?; - if src.len() != 0 { - return Err(ParseError::expected("remaining tokens in macro def")); + match Rule::parse(&mut src, false) { + Ok(rule) => { + if src.len() != 0 { + err = Some(Box::new(ParseError::expected("remaining tokens in macro def"))); + } + rules.push(rule); + } + Err(e) => { + err = Some(Box::new(e)); + } } - rules.push(rule); } for Rule { lhs, .. } in &rules { - validate(lhs)?; + if let Err(e) = validate(lhs) { + err = Some(Box::new(e)); + break; + } } - Ok(DeclarativeMacro { rules: rules.into_boxed_slice(), shift: Shift::new(tt), is_2021 }) + DeclarativeMacro { rules: rules.into_boxed_slice(), shift: Shift::new(tt), is_2021, err } } pub fn expand(&self, tt: &tt::Subtree) -> ExpandResult { @@ -282,6 +318,10 @@ impl DeclarativeMacro { expander::expand_rules(&self.rules, &tt, self.is_2021) } + pub fn err(&self) -> Option<&ParseError> { + self.err.as_deref() + } + pub fn map_id_down(&self, id: tt::TokenId) -> tt::TokenId { self.shift.shift(id) } From f6c09099daa5bb0b3aba405c38e6d580a5436702 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 10 Jul 2023 16:28:23 +0200 Subject: [PATCH 3/3] Don't unnecessarily clone the input tt for decl macros --- crates/hir-expand/src/db.rs | 27 ++++++++++++++++++++------- crates/mbe/src/benchmark.rs | 4 ++-- crates/mbe/src/lib.rs | 4 +--- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index b138b91a8a5..a2b642cd114 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -35,7 +35,7 @@ pub struct DeclarativeMacroExpander { } impl DeclarativeMacroExpander { - pub fn expand(&self, tt: &tt::Subtree) -> ExpandResult { + pub fn expand(&self, tt: tt::Subtree) -> ExpandResult { match self.mac.err() { Some(e) => ExpandResult::new( tt::Subtree::empty(), @@ -44,6 +44,14 @@ impl DeclarativeMacroExpander { None => self.mac.expand(tt).map_err(Into::into), } } + + pub fn map_id_down(&self, token_id: tt::TokenId) -> tt::TokenId { + self.mac.map_id_down(token_id) + } + + pub fn map_id_up(&self, token_id: tt::TokenId) -> (tt::TokenId, mbe::Origin) { + self.mac.map_id_up(token_id) + } } #[derive(Debug, Clone, Eq, PartialEq)] @@ -61,10 +69,11 @@ pub enum TokenExpander { ProcMacro(ProcMacroExpander), } +// FIXME: Get rid of these methods impl TokenExpander { pub(crate) fn map_id_down(&self, id: tt::TokenId) -> tt::TokenId { match self { - TokenExpander::DeclarativeMacro(expander) => expander.mac.map_id_down(id), + TokenExpander::DeclarativeMacro(expander) => expander.map_id_down(id), TokenExpander::BuiltIn(..) | TokenExpander::BuiltInEager(..) | TokenExpander::BuiltInAttr(..) @@ -75,7 +84,7 @@ impl TokenExpander { pub(crate) fn map_id_up(&self, id: tt::TokenId) -> (tt::TokenId, mbe::Origin) { match self { - TokenExpander::DeclarativeMacro(expander) => expander.mac.map_id_up(id), + TokenExpander::DeclarativeMacro(expander) => expander.map_id_up(id), TokenExpander::BuiltIn(..) | TokenExpander::BuiltInEager(..) | TokenExpander::BuiltInAttr(..) @@ -167,7 +176,6 @@ pub fn expand_speculative( token_to_map: SyntaxToken, ) -> Option<(SyntaxNode, SyntaxToken)> { let loc = db.lookup_intern_macro_call(actual_macro_call); - let macro_def = db.macro_def(loc.def); let token_range = token_to_map.text_range(); // Build the subtree and token mapping for the speculative args @@ -225,7 +233,12 @@ pub fn expand_speculative( None => { let range = token_range.checked_sub(speculative_args.text_range().start())?; let token_id = spec_args_tmap.token_by_range(range)?; - macro_def.map_id_down(token_id) + match loc.def.kind { + MacroDefKind::Declarative(it) => { + db.decl_macro_expander(loc.krate, it).map_id_down(token_id) + } + _ => token_id, + } } }; @@ -244,7 +257,7 @@ pub fn expand_speculative( let adt = ast::Adt::cast(speculative_args.clone()).unwrap(); expander.expand(db, actual_macro_call, &adt, &spec_args_tmap) } - MacroDefKind::Declarative(it) => db.decl_macro_expander(loc.krate, it).expand(&tt), + MacroDefKind::Declarative(it) => db.decl_macro_expander(loc.krate, it).expand(tt), MacroDefKind::BuiltIn(it, _) => it.expand(db, actual_macro_call, &tt).map_err(Into::into), MacroDefKind::BuiltInEager(it, _) => { it.expand(db, actual_macro_call, &tt).map_err(Into::into) @@ -518,7 +531,7 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult { - db.decl_macro_expander(loc.def.krate, id).expand(&arg) + db.decl_macro_expander(loc.def.krate, id).expand(arg.clone()) } MacroDefKind::BuiltIn(it, _) => it.expand(db, id, &arg).map_err(Into::into), MacroDefKind::BuiltInEager(it, _) => it.expand(db, id, &arg).map_err(Into::into), diff --git a/crates/mbe/src/benchmark.rs b/crates/mbe/src/benchmark.rs index 793326a8417..9d43e130457 100644 --- a/crates/mbe/src/benchmark.rs +++ b/crates/mbe/src/benchmark.rs @@ -38,7 +38,7 @@ fn benchmark_expand_macro_rules() { invocations .into_iter() .map(|(id, tt)| { - let res = rules[&id].expand(&tt); + let res = rules[&id].expand(tt); assert!(res.err.is_none()); res.value.token_trees.len() }) @@ -102,7 +102,7 @@ fn invocation_fixtures(rules: &FxHashMap) -> Vec<(Stri for op in rule.lhs.iter() { collect_from_op(op, &mut subtree, &mut seed); } - if it.expand(&subtree).err.is_none() { + if it.expand(subtree.clone()).err.is_none() { res.push((name.clone(), subtree)); break; } diff --git a/crates/mbe/src/lib.rs b/crates/mbe/src/lib.rs index b528af8c6a7..c17ba1c58e2 100644 --- a/crates/mbe/src/lib.rs +++ b/crates/mbe/src/lib.rs @@ -311,9 +311,7 @@ impl DeclarativeMacro { DeclarativeMacro { rules: rules.into_boxed_slice(), shift: Shift::new(tt), is_2021, err } } - pub fn expand(&self, tt: &tt::Subtree) -> ExpandResult { - // apply shift - let mut tt = tt.clone(); + pub fn expand(&self, mut tt: tt::Subtree) -> ExpandResult { self.shift.shift_all(&mut tt); expander::expand_rules(&self.rules, &tt, self.is_2021) }