From bce4be9478f18eded9412bb064d5ef193189d9aa Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Mon, 25 Sep 2023 21:01:54 -0700 Subject: [PATCH 1/3] fix: make bool_to_enum assist create enum at top-level --- .../ide-assists/src/handlers/bool_to_enum.rs | 193 +++++++++++++++--- crates/ide-assists/src/tests/generated.rs | 6 +- 2 files changed, 163 insertions(+), 36 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 85b0b87d0c9..3303a2dd3c7 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -16,7 +16,7 @@ use syntax::{ edit_in_place::{AttrsOwnerEdit, Indent}, make, HasName, }, - ted, AstNode, NodeOrToken, SyntaxNode, T, + match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T, }; use text_edit::TextRange; @@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists}; // ``` // -> // ``` -// fn main() { -// #[derive(PartialEq, Eq)] -// enum Bool { True, False } +// #[derive(PartialEq, Eq)] +// enum Bool { True, False } // +// fn main() { // let bool = Bool::True; // // if bool == Bool::True { @@ -270,6 +270,10 @@ fn replace_usages( } _ => (), } + } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name) + { + edit.replace(ty_annotation.syntax().text_range(), "Bool"); + replace_bool_expr(edit, initializer); } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant if let Some((record_field, expr)) = new_name @@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option { } } +fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> { + let const_ = name.syntax().parent().and_then(ast::Const::cast)?; + if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() { + return None; + } + + Some((const_.ty()?, const_.body()?)) +} + /// Adds the definition of the new enum before the target node. fn add_enum_def( edit: &mut SourceChangeBuilder, @@ -430,11 +443,12 @@ fn add_enum_def( .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); - let indent = IndentLevel::from_node(&target_node); + let insert_before = node_to_insert_before(target_node); + let indent = IndentLevel::from_node(&insert_before); enum_def.reindent_to(indent); ted::insert_all( - ted::Position::before(&edit.make_syntax_mut(target_node)), + ted::Position::before(&edit.make_syntax_mut(insert_before)), vec![ enum_def.syntax().clone().into(), make::tokens::whitespace(&format!("\n\n{indent}")).into(), @@ -442,6 +456,35 @@ fn add_enum_def( ); } +/// Finds where to put the new enum definition, at the nearest module or at top-level. +fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode { + let mut ancestors = target_node.ancestors(); + + while let Some(ancestor) = ancestors.next() { + match_ast! { + match ancestor { + ast::Item(item) => { + if item + .syntax() + .parent() + .and_then(|item_list| item_list.parent()) + .and_then(ast::Module::cast) + .is_some() + { + return ancestor; + } + }, + ast::SourceFile(_) => break, + _ => (), + } + } + + target_node = ancestor; + } + + target_node +} + fn make_bool_enum(make_pub: bool) -> ast::Enum { let enum_def = make::enum_( if make_pub { Some(make::visibility_pub()) } else { None }, @@ -491,10 +534,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::True { @@ -520,10 +563,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::False { @@ -545,10 +588,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool = Bool::False; } "#, @@ -565,10 +608,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = if 1 == 2 { Bool::True } else { Bool::False }; } "#, @@ -590,10 +633,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::False; let bar = true; @@ -619,10 +662,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if *&foo == Bool::True { @@ -645,10 +688,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool; foo = Bool::True; } @@ -671,10 +714,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; let bar = foo == Bool::False; @@ -702,11 +745,11 @@ fn main() { } "#, r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + fn main() { if !"foo".chars().any(|c| { - #[derive(PartialEq, Eq)] - enum Bool { True, False } - let foo = Bool::True; foo == Bool::True }) { @@ -1445,6 +1488,90 @@ pub mod bar { ) } + #[test] + fn const_in_impl_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +struct Foo; + +impl Foo { + pub const $0BOOL: bool = true; +} + +//- /foo.rs +use crate::Foo; + +fn foo() -> bool { + Foo::BOOL +} +"#, + r#" +//- /main.rs +mod foo; + +struct Foo; + +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +impl Foo { + pub const BOOL: Bool = Bool::True; +} + +//- /foo.rs +use crate::{Foo, Bool}; + +fn foo() -> bool { + Foo::BOOL == Bool::True +} +"#, + ) + } + + #[test] + fn const_in_trait() { + check_assist( + bool_to_enum, + r#" +trait Foo { + const $0BOOL: bool; +} + +impl Foo for usize { + const BOOL: bool = true; +} + +fn main() { + if ::BOOL { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +trait Foo { + const BOOL: Bool; +} + +impl Foo for usize { + const BOOL: Bool = Bool::True; +} + +fn main() { + if ::BOOL == Bool::True { + println!("foo"); + } +} +"#, + ) + } + #[test] fn const_non_bool() { cov_mark::check!(not_applicable_non_bool_const); diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 63a08a0e569..5a815d5c6a1 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -294,10 +294,10 @@ fn main() { } "#####, r#####" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let bool = Bool::True; if bool == Bool::True { From 73150c3f360d399126c1ce9ce2f9846b9f0b5293 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Mon, 25 Sep 2023 21:44:16 -0700 Subject: [PATCH 2/3] fix: wrap method call exprs in parens --- .../ide-assists/src/handlers/bool_to_enum.rs | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 3303a2dd3c7..1c0cbb9dfda 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -274,6 +274,11 @@ fn replace_usages( { edit.replace(ty_annotation.syntax().text_range(), "Bool"); replace_bool_expr(edit, initializer); + } else if let Some(receiver) = find_method_call_expr_usage(&new_name) { + edit.replace( + receiver.syntax().text_range(), + format!("({} == Bool::True)", receiver), + ); } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant if let Some((record_field, expr)) = new_name @@ -426,6 +431,17 @@ fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr) Some((const_.ty()?, const_.body()?)) } +fn find_method_call_expr_usage(name: &ast::NameLike) -> Option { + let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?; + let receiver = method_call.receiver()?; + + if !receiver.syntax().descendants().contains(name.syntax()) { + return None; + } + + Some(receiver) +} + /// Adds the definition of the new enum before the target node. fn add_enum_def( edit: &mut SourceChangeBuilder, @@ -1287,6 +1303,38 @@ fn main() { ) } + #[test] + fn field_method_chain_usage() { + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bool: bool, +} + +fn main() { + let foo = Foo { bool: true }; + + foo.bool.then(|| 2); +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + bool: Bool, +} + +fn main() { + let foo = Foo { bool: Bool::True }; + + (foo.bool == Bool::True).then(|| 2); +} +"#, + ) + } + #[test] fn field_non_bool() { cov_mark::check!(not_applicable_non_bool_field); From 1b3e5b2105b20b3237efaac17c4a9761890f6597 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Thu, 28 Sep 2023 10:09:13 -0700 Subject: [PATCH 3/3] style: simplify node_to_insert_before --- .../ide-assists/src/handlers/bool_to_enum.rs | 37 +++++-------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 1c0cbb9dfda..082839118c5 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -16,7 +16,7 @@ use syntax::{ edit_in_place::{AttrsOwnerEdit, Indent}, make, HasName, }, - match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T, + ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T, }; use text_edit::TextRange; @@ -472,33 +472,16 @@ fn add_enum_def( ); } -/// Finds where to put the new enum definition, at the nearest module or at top-level. -fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode { - let mut ancestors = target_node.ancestors(); - - while let Some(ancestor) = ancestors.next() { - match_ast! { - match ancestor { - ast::Item(item) => { - if item - .syntax() - .parent() - .and_then(|item_list| item_list.parent()) - .and_then(ast::Module::cast) - .is_some() - { - return ancestor; - } - }, - ast::SourceFile(_) => break, - _ => (), - } - } - - target_node = ancestor; - } - +/// Finds where to put the new enum definition. +/// Tries to find the ast node at the nearest module or at top-level, otherwise just +/// returns the input node. +fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode { target_node + .ancestors() + .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE)) + .filter(|it| ast::Item::can_cast(it.kind())) + .last() + .unwrap_or(target_node) } fn make_bool_enum(make_pub: bool) -> ast::Enum {