Auto merge of #15667 - rmehri01:bool_to_enum_top_level, r=Veykril

fix: make bool_to_enum assist create enum at top-level

This pr makes the `bool_to_enum` assist create the `enum` at the next closest module block or at top-level, which fixes a few tricky cases such as with an associated `const` in a trait or module:

```rust
trait Foo {
    const $0BOOL: bool;
}

impl Foo for usize {
    const BOOL: bool = true;
}

fn main() {
    if <usize as Foo>::BOOL {
        println!("foo");
    }
}
```

Which now properly produces:

```rust
#[derive(PartialEq, Eq)]
enum Bool { True, False }

trait Foo {
    const BOOL: Bool;
}

impl Foo for usize {
    const BOOL: Bool = Bool::True;
}

fn main() {
    if <usize as Foo>::BOOL == Bool::True {
        println!("foo");
    }
}
```

I also think it's a bit nicer, especially for local variables, but didn't really know to do it in the first PR :)
This commit is contained in:
bors 2023-09-29 10:20:11 +00:00
commit 87e2c310f9
2 changed files with 194 additions and 36 deletions

View File

@ -16,7 +16,7 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent}, edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName, make, HasName,
}, },
ted, AstNode, NodeOrToken, SyntaxNode, T, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
}; };
use text_edit::TextRange; use text_edit::TextRange;
@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
// ``` // ```
// -> // ->
// ``` // ```
// fn main() {
// #[derive(PartialEq, Eq)] // #[derive(PartialEq, Eq)]
// enum Bool { True, False } // enum Bool { True, False }
// //
// fn main() {
// let bool = Bool::True; // let bool = Bool::True;
// //
// if bool == Bool::True { // if bool == Bool::True {
@ -270,6 +270,15 @@ fn replace_usages(
} }
_ => (), _ => (),
} }
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
{
edit.replace(ty_annotation.syntax().text_range(), "Bool");
replace_bool_expr(edit, initializer);
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
edit.replace(
receiver.syntax().text_range(),
format!("({} == Bool::True)", receiver),
);
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant // for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name if let Some((record_field, expr)) = new_name
@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
} }
} }
fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
return None;
}
Some((const_.ty()?, const_.body()?))
}
fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
let receiver = method_call.receiver()?;
if !receiver.syntax().descendants().contains(name.syntax()) {
return None;
}
Some(receiver)
}
/// Adds the definition of the new enum before the target node. /// Adds the definition of the new enum before the target node.
fn add_enum_def( fn add_enum_def(
edit: &mut SourceChangeBuilder, edit: &mut SourceChangeBuilder,
@ -430,11 +459,12 @@ fn add_enum_def(
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub); let enum_def = make_bool_enum(make_enum_pub);
let indent = IndentLevel::from_node(&target_node); let insert_before = node_to_insert_before(target_node);
let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent); enum_def.reindent_to(indent);
ted::insert_all( ted::insert_all(
ted::Position::before(&edit.make_syntax_mut(target_node)), ted::Position::before(&edit.make_syntax_mut(insert_before)),
vec![ vec![
enum_def.syntax().clone().into(), enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(), make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@ -442,6 +472,18 @@ fn add_enum_def(
); );
} }
/// Finds where to put the new enum definition.
/// Tries to find the ast node at the nearest module or at top-level, otherwise just
/// returns the input node.
fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
target_node
.ancestors()
.take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
.filter(|it| ast::Item::can_cast(it.kind()))
.last()
.unwrap_or(target_node)
}
fn make_bool_enum(make_pub: bool) -> ast::Enum { fn make_bool_enum(make_pub: bool) -> ast::Enum {
let enum_def = make::enum_( let enum_def = make::enum_(
if make_pub { Some(make::visibility_pub()) } else { None }, if make_pub { Some(make::visibility_pub()) } else { None },
@ -491,10 +533,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if foo == Bool::True { if foo == Bool::True {
@ -520,10 +562,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if foo == Bool::False { if foo == Bool::False {
@ -545,10 +587,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo: Bool = Bool::False; let foo: Bool = Bool::False;
} }
"#, "#,
@ -565,10 +607,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = if 1 == 2 { Bool::True } else { Bool::False }; let foo = if 1 == 2 { Bool::True } else { Bool::False };
} }
"#, "#,
@ -590,10 +632,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::False; let foo = Bool::False;
let bar = true; let bar = true;
@ -619,10 +661,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if *&foo == Bool::True { if *&foo == Bool::True {
@ -645,10 +687,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo: Bool; let foo: Bool;
foo = Bool::True; foo = Bool::True;
} }
@ -671,10 +713,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
let bar = foo == Bool::False; let bar = foo == Bool::False;
@ -702,11 +744,11 @@ fn main() {
} }
"#, "#,
r#" r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() { fn main() {
if !"foo".chars().any(|c| { if !"foo".chars().any(|c| {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
let foo = Bool::True; let foo = Bool::True;
foo == Bool::True foo == Bool::True
}) { }) {
@ -1244,6 +1286,38 @@ fn main() {
) )
} }
#[test]
fn field_method_chain_usage() {
check_assist(
bool_to_enum,
r#"
struct Foo {
$0bool: bool,
}
fn main() {
let foo = Foo { bool: true };
foo.bool.then(|| 2);
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Foo {
bool: Bool,
}
fn main() {
let foo = Foo { bool: Bool::True };
(foo.bool == Bool::True).then(|| 2);
}
"#,
)
}
#[test] #[test]
fn field_non_bool() { fn field_non_bool() {
cov_mark::check!(not_applicable_non_bool_field); cov_mark::check!(not_applicable_non_bool_field);
@ -1445,6 +1519,90 @@ pub mod bar {
) )
} }
#[test]
fn const_in_impl_cross_file() {
check_assist(
bool_to_enum,
r#"
//- /main.rs
mod foo;
struct Foo;
impl Foo {
pub const $0BOOL: bool = true;
}
//- /foo.rs
use crate::Foo;
fn foo() -> bool {
Foo::BOOL
}
"#,
r#"
//- /main.rs
mod foo;
struct Foo;
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
impl Foo {
pub const BOOL: Bool = Bool::True;
}
//- /foo.rs
use crate::{Foo, Bool};
fn foo() -> bool {
Foo::BOOL == Bool::True
}
"#,
)
}
#[test]
fn const_in_trait() {
check_assist(
bool_to_enum,
r#"
trait Foo {
const $0BOOL: bool;
}
impl Foo for usize {
const BOOL: bool = true;
}
fn main() {
if <usize as Foo>::BOOL {
println!("foo");
}
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
trait Foo {
const BOOL: Bool;
}
impl Foo for usize {
const BOOL: Bool = Bool::True;
}
fn main() {
if <usize as Foo>::BOOL == Bool::True {
println!("foo");
}
}
"#,
)
}
#[test] #[test]
fn const_non_bool() { fn const_non_bool() {
cov_mark::check!(not_applicable_non_bool_const); cov_mark::check!(not_applicable_non_bool_const);

View File

@ -294,10 +294,10 @@ fn main() {
} }
"#####, "#####,
r#####" r#####"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let bool = Bool::True; let bool = Bool::True;
if bool == Bool::True { if bool == Bool::True {