diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 85b0b87d0c9..082839118c5 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -16,7 +16,7 @@ use syntax::{ edit_in_place::{AttrsOwnerEdit, Indent}, make, HasName, }, - ted, AstNode, NodeOrToken, SyntaxNode, T, + ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T, }; use text_edit::TextRange; @@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists}; // ``` // -> // ``` -// fn main() { -// #[derive(PartialEq, Eq)] -// enum Bool { True, False } +// #[derive(PartialEq, Eq)] +// enum Bool { True, False } // +// fn main() { // let bool = Bool::True; // // if bool == Bool::True { @@ -270,6 +270,15 @@ fn replace_usages( } _ => (), } + } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name) + { + edit.replace(ty_annotation.syntax().text_range(), "Bool"); + replace_bool_expr(edit, initializer); + } else if let Some(receiver) = find_method_call_expr_usage(&new_name) { + edit.replace( + receiver.syntax().text_range(), + format!("({} == Bool::True)", receiver), + ); } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant if let Some((record_field, expr)) = new_name @@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option { } } +fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> { + let const_ = name.syntax().parent().and_then(ast::Const::cast)?; + if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() { + return None; + } + + Some((const_.ty()?, const_.body()?)) +} + +fn find_method_call_expr_usage(name: &ast::NameLike) -> Option { + let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?; + let receiver = method_call.receiver()?; + + if !receiver.syntax().descendants().contains(name.syntax()) { + return None; + } + + Some(receiver) +} + /// Adds the definition of the new enum before the target node. fn add_enum_def( edit: &mut SourceChangeBuilder, @@ -430,11 +459,12 @@ fn add_enum_def( .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); - let indent = IndentLevel::from_node(&target_node); + let insert_before = node_to_insert_before(target_node); + let indent = IndentLevel::from_node(&insert_before); enum_def.reindent_to(indent); ted::insert_all( - ted::Position::before(&edit.make_syntax_mut(target_node)), + ted::Position::before(&edit.make_syntax_mut(insert_before)), vec![ enum_def.syntax().clone().into(), make::tokens::whitespace(&format!("\n\n{indent}")).into(), @@ -442,6 +472,18 @@ fn add_enum_def( ); } +/// Finds where to put the new enum definition. +/// Tries to find the ast node at the nearest module or at top-level, otherwise just +/// returns the input node. +fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode { + target_node + .ancestors() + .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE)) + .filter(|it| ast::Item::can_cast(it.kind())) + .last() + .unwrap_or(target_node) +} + fn make_bool_enum(make_pub: bool) -> ast::Enum { let enum_def = make::enum_( if make_pub { Some(make::visibility_pub()) } else { None }, @@ -491,10 +533,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::True { @@ -520,10 +562,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::False { @@ -545,10 +587,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool = Bool::False; } "#, @@ -565,10 +607,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = if 1 == 2 { Bool::True } else { Bool::False }; } "#, @@ -590,10 +632,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::False; let bar = true; @@ -619,10 +661,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if *&foo == Bool::True { @@ -645,10 +687,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool; foo = Bool::True; } @@ -671,10 +713,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; let bar = foo == Bool::False; @@ -702,11 +744,11 @@ fn main() { } "#, r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + fn main() { if !"foo".chars().any(|c| { - #[derive(PartialEq, Eq)] - enum Bool { True, False } - let foo = Bool::True; foo == Bool::True }) { @@ -1244,6 +1286,38 @@ fn main() { ) } + #[test] + fn field_method_chain_usage() { + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bool: bool, +} + +fn main() { + let foo = Foo { bool: true }; + + foo.bool.then(|| 2); +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + bool: Bool, +} + +fn main() { + let foo = Foo { bool: Bool::True }; + + (foo.bool == Bool::True).then(|| 2); +} +"#, + ) + } + #[test] fn field_non_bool() { cov_mark::check!(not_applicable_non_bool_field); @@ -1445,6 +1519,90 @@ pub mod bar { ) } + #[test] + fn const_in_impl_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +struct Foo; + +impl Foo { + pub const $0BOOL: bool = true; +} + +//- /foo.rs +use crate::Foo; + +fn foo() -> bool { + Foo::BOOL +} +"#, + r#" +//- /main.rs +mod foo; + +struct Foo; + +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +impl Foo { + pub const BOOL: Bool = Bool::True; +} + +//- /foo.rs +use crate::{Foo, Bool}; + +fn foo() -> bool { + Foo::BOOL == Bool::True +} +"#, + ) + } + + #[test] + fn const_in_trait() { + check_assist( + bool_to_enum, + r#" +trait Foo { + const $0BOOL: bool; +} + +impl Foo for usize { + const BOOL: bool = true; +} + +fn main() { + if ::BOOL { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +trait Foo { + const BOOL: Bool; +} + +impl Foo for usize { + const BOOL: Bool = Bool::True; +} + +fn main() { + if ::BOOL == Bool::True { + println!("foo"); + } +} +"#, + ) + } + #[test] fn const_non_bool() { cov_mark::check!(not_applicable_non_bool_const); diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 63a08a0e569..5a815d5c6a1 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -294,10 +294,10 @@ fn main() { } "#####, r#####" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let bool = Bool::True; if bool == Bool::True {