From 1141259a23cedb07b07dac3ccf4ff054d92adcf5 Mon Sep 17 00:00:00 2001 From: Wyatt Herkamp Date: Sun, 24 Mar 2024 09:51:08 -0400 Subject: [PATCH 1/3] Init Wrap/Unwrap cfg_attr --- .../src/handlers/wrap_unwrap_cfg_attr.rs | 538 ++++++++++++++++++ crates/ide-assists/src/lib.rs | 3 + 2 files changed, 541 insertions(+) create mode 100644 crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs diff --git a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs new file mode 100644 index 00000000000..4b8fb1e78a8 --- /dev/null +++ b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs @@ -0,0 +1,538 @@ +use ide_db::source_change::SourceChangeBuilder; +use itertools::Itertools; +use syntax::{ + algo, + ast::{self, make, AstNode}, + ted::{self, Position}, + AstToken, NodeOrToken, SyntaxToken, TextRange, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: wrap_unwrap_cfg_attr +// +// Wraps an attribute to a cfg_attr attribute or unwraps a cfg_attr attribute to the inner attributes. +// +// ``` +// #[derive$0(Debug)] +// struct S { +// field: i32 +// } +// ``` +// -> +// ``` +// #[cfg_attr($0, derive(Debug))] +// struct S { +// field: i32 +// } + +enum WrapUnwrapOption { + WrapDerive { derive: TextRange, attr: ast::Attr }, + WrapAttr(ast::Attr), +} + +/// Attempts to get the derive attribute from a derive attribute list +/// +/// This will collect all the tokens in the "path" within the derive attribute list +/// But a derive attribute list doesn't have paths. So we need to collect all the tokens before and after the ident +/// +/// If this functions return None just map to WrapAttr +fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption { + let attempt_attr = || { + { + let mut derive = ident.text_range(); + // TokenTree is all the tokens between the `(` and `)`. They do not have paths. So a path `serde::Serialize` would be [Ident Colon Colon Ident] + // So lets say we have derive(Debug, serde::Serialize, Copy) ident would be on Serialize + // We need to grab all previous tokens until we find a `,` or `(` and all following tokens until we find a `,` or `)` + // We also want to consume the following comma if it exists + + let mut prev = algo::skip_trivia_token( + ident.prev_sibling_or_token()?.into_token()?, + syntax::Direction::Prev, + )?; + let mut following = algo::skip_trivia_token( + ident.next_sibling_or_token()?.into_token()?, + syntax::Direction::Next, + )?; + if (prev.kind() == T![,] || prev.kind() == T!['(']) + && (following.kind() == T![,] || following.kind() == T!['(']) + { + // This would be a single ident such as Debug. As no path is present + if following.kind() == T![,] { + derive = derive.cover(following.text_range()); + } + + Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() }) + } else { + // Collect the path + + while let Some(prev_token) = algo::skip_trivia_token(prev, syntax::Direction::Prev) + { + let kind = prev_token.kind(); + if kind == T![,] || kind == T!['('] { + break; + } + derive = derive.cover(prev_token.text_range()); + prev = prev_token.prev_sibling_or_token()?.into_token()?; + } + while let Some(next_token) = + algo::skip_trivia_token(following.clone(), syntax::Direction::Next) + { + let kind = next_token.kind(); + if kind != T![')'] { + // We also want to consume a following comma + derive = derive.cover(next_token.text_range()); + } + following = next_token.next_sibling_or_token()?.into_token()?; + + if kind == T![,] || kind == T![')'] { + break; + } + } + Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() }) + } + } + }; + if ident.parent().and_then(ast::TokenTree::cast).is_none() + || !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() + { + WrapUnwrapOption::WrapAttr(attr) + } else { + attempt_attr().unwrap_or(WrapUnwrapOption::WrapAttr(attr)) + } +} +pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let option = if ctx.has_empty_selection() { + let ident = ctx.find_token_at_offset::().map(|v| v.syntax().clone()); + let attr = ctx.find_node_at_offset::(); + match (attr, ident) { + (Some(attr), Some(ident)) + if attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() => + { + Some(attempt_get_derive(attr.clone(), ident)) + } + (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(attr)), + _ => None, + } + } else { + let covering_element = ctx.covering_element(); + match covering_element { + NodeOrToken::Node(node) => ast::Attr::cast(node).map(WrapUnwrapOption::WrapAttr), + NodeOrToken::Token(ident) if ident.kind() == syntax::T![ident] => { + let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?; + Some(attempt_get_derive(attr.clone(), ident)) + } + _ => None, + } + }?; + match option { + WrapUnwrapOption::WrapAttr(attr) if attr.simple_name().as_deref() == Some("cfg_attr") => { + unwrap_cfg_attr(acc, attr) + } + WrapUnwrapOption::WrapAttr(attr) => wrap_cfg_attr(acc, ctx, attr), + WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive), + } +} + +fn wrap_derive( + acc: &mut Assists, + ctx: &AssistContext<'_>, + attr: ast::Attr, + derive_element: TextRange, +) -> Option<()> { + let range = attr.syntax().text_range(); + let token_tree = attr.token_tree()?; + let mut path_text = String::new(); + + let mut cfg_derive_tokens = Vec::new(); + let mut new_derive = Vec::new(); + + for tt in token_tree.token_trees_and_tokens() { + let NodeOrToken::Token(token) = tt else { + continue; + }; + if token.kind() == T!['('] || token.kind() == T![')'] { + continue; + } + + if derive_element.contains_range(token.text_range()) { + if token.kind() != T![,] { + path_text.push_str(token.text()); + cfg_derive_tokens.push(NodeOrToken::Token(token)); + } + } else { + new_derive.push(NodeOrToken::Token(token)); + } + } + let handle_source_change = |edit: &mut SourceChangeBuilder| { + let new_derive = make::attr_outer(make::meta_token_tree( + make::ext::ident_path("derive"), + make::token_tree(T!['('], new_derive), + )) + .clone_for_update(); + let meta = make::meta_token_tree( + make::ext::ident_path("cfg_attr"), + make::token_tree( + T!['('], + vec![ + NodeOrToken::Token(make::token(T![,])), + NodeOrToken::Token(make::tokens::whitespace(" ")), + NodeOrToken::Token(make::tokens::ident("derive")), + NodeOrToken::Node(make::token_tree(T!['('], cfg_derive_tokens)), + ], + ), + ); + // Remove the derive attribute + let edit_attr = edit.make_syntax_mut(attr.syntax().clone()); + + ted::replace(edit_attr, new_derive.syntax().clone()); + let cfg_attr = make::attr_outer(meta).clone_for_update(); + + ted::insert_all_raw( + Position::after(new_derive.syntax().clone()), + vec![make::tokens::whitespace("\n").into(), cfg_attr.syntax().clone().into()], + ); + if let Some(snippet_cap) = ctx.config.snippet_cap { + if let Some(first_meta) = + cfg_attr.meta().and_then(|meta| meta.token_tree()).and_then(|tt| tt.l_paren_token()) + { + edit.add_tabstop_after_token(snippet_cap, first_meta) + } + } + }; + + acc.add( + AssistId("wrap_unwrap_cfg_attr", AssistKind::Refactor), + format!("Wrap #[derive({path_text})] in `cfg_attr`",), + range, + handle_source_change, + ); + Some(()) +} +fn wrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>, attr: ast::Attr) -> Option<()> { + let range = attr.syntax().text_range(); + let path = attr.path()?; + let handle_source_change = |edit: &mut SourceChangeBuilder| { + let mut raw_tokens = vec![ + NodeOrToken::Token(make::token(T![,])), + NodeOrToken::Token(make::tokens::whitespace(" ")), + ]; + path.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + if let Some(meta) = attr.meta() { + if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) { + raw_tokens.push(NodeOrToken::Token(make::tokens::whitespace(" "))); + raw_tokens.push(NodeOrToken::Token(eq.clone())); + raw_tokens.push(NodeOrToken::Token(make::tokens::whitespace(" "))); + + expr.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + } else if let Some(tt) = meta.token_tree() { + raw_tokens.extend(tt.token_trees_and_tokens()); + } + } + let meta = make::meta_token_tree( + make::ext::ident_path("cfg_attr"), + make::token_tree(T!['('], raw_tokens), + ); + let cfg_attr = if attr.excl_token().is_some() { + make::attr_inner(meta) + } else { + make::attr_outer(meta) + } + .clone_for_update(); + let attr_syntax = edit.make_syntax_mut(attr.syntax().clone()); + ted::replace(attr_syntax, cfg_attr.syntax()); + + if let Some(snippet_cap) = ctx.config.snippet_cap { + if let Some(first_meta) = + cfg_attr.meta().and_then(|meta| meta.token_tree()).and_then(|tt| tt.l_paren_token()) + { + edit.add_tabstop_after_token(snippet_cap, first_meta) + } + } + }; + acc.add( + AssistId("wrap_unwrap_cfg_attr", AssistKind::Refactor), + "Convert to `cfg_attr`", + range, + handle_source_change, + ); + Some(()) +} +fn unwrap_cfg_attr(acc: &mut Assists, attr: ast::Attr) -> Option<()> { + let range = attr.syntax().text_range(); + let meta = attr.meta()?; + let meta_tt = meta.token_tree()?; + let mut inner_attrs = Vec::with_capacity(1); + let mut found_comma = false; + let mut iter = meta_tt.token_trees_and_tokens().skip(1).peekable(); + while let Some(tt) = iter.next() { + if let NodeOrToken::Token(token) = &tt { + if token.kind() == T![')'] { + break; + } + if token.kind() == T![,] { + found_comma = true; + continue; + } + } + if !found_comma { + continue; + } + let Some(attr_name) = tt.into_token().and_then(|token| { + if token.kind() == T![ident] { + Some(make::ext::ident_path(token.text())) + } else { + None + } + }) else { + continue; + }; + let next_tt = iter.next()?; + let meta = match next_tt { + NodeOrToken::Node(tt) => make::meta_token_tree(attr_name, tt), + NodeOrToken::Token(token) if token.kind() == T![,] || token.kind() == T![')'] => { + make::meta_path(attr_name) + } + NodeOrToken::Token(token) => { + let equals = algo::skip_trivia_token(token, syntax::Direction::Next)?; + if equals.kind() != T![=] { + return None; + } + let expr_token = + algo::skip_trivia_token(equals.next_token()?, syntax::Direction::Next) + .and_then(|it| { + if it.kind().is_literal() { + Some(make::expr_literal(it.text())) + } else { + None + } + })?; + make::meta_expr(attr_name, ast::Expr::Literal(expr_token)) + } + }; + if attr.excl_token().is_some() { + inner_attrs.push(make::attr_inner(meta)); + } else { + inner_attrs.push(make::attr_outer(meta)); + } + } + if inner_attrs.is_empty() { + return None; + } + let handle_source_change = |f: &mut SourceChangeBuilder| { + let inner_attrs = inner_attrs.iter().map(|it| it.to_string()).join("\n"); + f.replace(range, inner_attrs); + }; + acc.add( + AssistId("wrap_unwrap_cfg_attr", AssistKind::Refactor), + "Extract Inner Attributes from `cfg_attr`", + range, + handle_source_change, + ); + Some(()) +} +#[cfg(test)] +mod tests { + use crate::tests::check_assist; + + use super::*; + + #[test] + fn test_basic_to_from_cfg_attr() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive$0(Debug)] + pub struct Test { + test: u32, + } + "#, + r#" + #[cfg_attr($0, derive(Debug))] + pub struct Test { + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[cfg_attr(debug_assertions, $0 derive(Debug))] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive(Debug)] + pub struct Test { + test: u32, + } + "#, + ); + } + #[test] + fn to_from_path_attr() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[foo$0] + test: u32, + } + "#, + r#" + pub struct Test { + #[cfg_attr($0, foo)] + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[cfg_attr(debug_assertions$0, foo)] + test: u32, + } + "#, + r#" + pub struct Test { + #[foo] + test: u32, + } + "#, + ); + } + #[test] + fn to_from_eq_attr() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[foo = "bar"$0] + test: u32, + } + "#, + r#" + pub struct Test { + #[cfg_attr($0, foo = "bar")] + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[cfg_attr(debug_assertions$0, foo = "bar")] + test: u32, + } + "#, + r#" + pub struct Test { + #[foo = "bar"] + test: u32, + } + "#, + ); + } + #[test] + fn inner_attrs() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + #![no_std$0] + "#, + r#" + #![cfg_attr($0, no_std)] + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + #![cfg_attr(not(feature = "std")$0, no_std)] + "#, + r#" + #![no_std] + "#, + ); + } + #[test] + fn test_derive_wrap() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(Debug$0, Clone, Copy)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive( Clone, Copy)] + #[cfg_attr($0, derive(Debug))] + pub struct Test { + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(Clone, Debug$0, Copy)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive(Clone, Copy)] + #[cfg_attr($0, derive(Debug))] + pub struct Test { + test: u32, + } + "#, + ); + } + #[test] + fn test_derive_wrap_with_path() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(std::fmt::Debug$0, Clone, Copy)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive( Clone, Copy)] + #[cfg_attr($0, derive(std::fmt::Debug))] + pub struct Test { + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(Clone, std::fmt::Debug$0, Copy)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive(Clone, Copy)] + #[cfg_attr($0, derive(std::fmt::Debug))] + pub struct Test { + test: u32, + } + "#, + ); + } +} diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index 8f0b8f861c2..1364120ea23 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -217,6 +217,7 @@ mod handlers { mod unwrap_result_return_type; mod unwrap_tuple; mod wrap_return_type_in_result; + mod wrap_unwrap_cfg_attr; pub(crate) fn all() -> &'static [Handler] { &[ @@ -342,6 +343,8 @@ pub(crate) fn all() -> &'static [Handler] { unwrap_tuple::unwrap_tuple, unqualify_method_call::unqualify_method_call, wrap_return_type_in_result::wrap_return_type_in_result, + wrap_unwrap_cfg_attr::wrap_unwrap_cfg_attr, + // These are manually sorted for better priorities. By default, // priority is determined by the size of the target range (smaller // target wins). If the ranges are equal, position in this list is From ecac8e3514b84104dc932cb92524cd3152453218 Mon Sep 17 00:00:00 2001 From: Wyatt Herkamp Date: Sun, 24 Mar 2024 10:37:41 -0400 Subject: [PATCH 2/3] Format and codegen for attr --- crates/ide-assists/src/lib.rs | 2 +- crates/ide-assists/src/tests/generated.rs | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index 1364120ea23..024448ef1e7 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -344,7 +344,7 @@ pub(crate) fn all() -> &'static [Handler] { unqualify_method_call::unqualify_method_call, wrap_return_type_in_result::wrap_return_type_in_result, wrap_unwrap_cfg_attr::wrap_unwrap_cfg_attr, - + // These are manually sorted for better priorities. By default, // priority is determined by the size of the target range (smaller // target wins). If the ranges are equal, position in this list is diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index a66e199a75b..2bf09347672 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -3151,3 +3151,22 @@ fn foo() -> Result { Ok(42i32) } "#####, ) } + +#[test] +fn doctest_wrap_unwrap_cfg_attr() { + check_doc_test( + "wrap_unwrap_cfg_attr", + r#####" +#[derive$0(Debug)] +struct S { + field: i32 +} +"#####, + r#####" +#[cfg_attr($0, derive(Debug))] +struct S { + field: i32 +} +"#####, + ) +} From e3f9a0afe143f4f589c1d57bb1aab6c5770dc14a Mon Sep 17 00:00:00 2001 From: Wyatt Herkamp Date: Sun, 24 Mar 2024 10:38:03 -0400 Subject: [PATCH 3/3] Fixed cursor being at end --- .../src/handlers/wrap_unwrap_cfg_attr.rs | 73 +++++++++++++++---- 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs index 4b8fb1e78a8..0fa46ef43a1 100644 --- a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs +++ b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs @@ -4,7 +4,7 @@ algo, ast::{self, make, AstNode}, ted::{self, Position}, - AstToken, NodeOrToken, SyntaxToken, TextRange, T, + NodeOrToken, SyntaxToken, TextRange, T, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -55,39 +55,46 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption { syntax::Direction::Next, )?; if (prev.kind() == T![,] || prev.kind() == T!['(']) - && (following.kind() == T![,] || following.kind() == T!['(']) + && (following.kind() == T![,] || following.kind() == T![')']) { // This would be a single ident such as Debug. As no path is present if following.kind() == T![,] { derive = derive.cover(following.text_range()); + } else if following.kind() == T![')'] && prev.kind() == T![,] { + derive = derive.cover(prev.text_range()); } Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() }) } else { + let mut consumed_comma = false; // Collect the path - while let Some(prev_token) = algo::skip_trivia_token(prev, syntax::Direction::Prev) { let kind = prev_token.kind(); - if kind == T![,] || kind == T!['('] { + if kind == T![,] { + consumed_comma = true; + derive = derive.cover(prev_token.text_range()); break; + } else if kind == T!['('] { + break; + } else { + derive = derive.cover(prev_token.text_range()); } - derive = derive.cover(prev_token.text_range()); prev = prev_token.prev_sibling_or_token()?.into_token()?; } while let Some(next_token) = algo::skip_trivia_token(following.clone(), syntax::Direction::Next) { let kind = next_token.kind(); - if kind != T![')'] { - // We also want to consume a following comma - derive = derive.cover(next_token.text_range()); + match kind { + T![,] if !consumed_comma => { + derive = derive.cover(next_token.text_range()); + break; + } + T![')'] | T![,] => break, + _ => derive = derive.cover(next_token.text_range()), } following = next_token.next_sibling_or_token()?.into_token()?; - - if kind == T![,] || kind == T![')'] { - break; - } } Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() }) } @@ -103,7 +110,7 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption { } pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let option = if ctx.has_empty_selection() { - let ident = ctx.find_token_at_offset::().map(|v| v.syntax().clone()); + let ident = ctx.find_token_syntax_at_offset(T![ident]); let attr = ctx.find_node_at_offset::(); match (attr, ident) { (Some(attr), Some(ident)) @@ -111,6 +118,7 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - { Some(attempt_get_derive(attr.clone(), ident)) } + (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(attr)), _ => None, } @@ -156,7 +164,7 @@ fn wrap_derive( } if derive_element.contains_range(token.text_range()) { - if token.kind() != T![,] { + if token.kind() != T![,] && token.kind() != syntax::SyntaxKind::WHITESPACE { path_text.push_str(token.text()); cfg_derive_tokens.push(NodeOrToken::Token(token)); } @@ -527,7 +535,42 @@ pub struct Test { } "#, r#" - #[derive(Clone, Copy)] + #[derive(Clone, Copy)] + #[cfg_attr($0, derive(std::fmt::Debug))] + pub struct Test { + test: u32, + } + "#, + ); + } + #[test] + fn test_derive_wrap_at_end() { + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(std::fmt::Debug, Clone, Cop$0y)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive(std::fmt::Debug, Clone)] + #[cfg_attr($0, derive(Copy))] + pub struct Test { + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + #[derive(Clone, Copy, std::fmt::D$0ebug)] + pub struct Test { + test: u32, + } + "#, + r#" + #[derive(Clone, Copy)] #[cfg_attr($0, derive(std::fmt::Debug))] pub struct Test { test: u32,