Auto merge of #15700 - rmehri01:15694_iterator_demorgan, r=Veykril

feat: add assist for applying De Morgan's law to `Iterator::all` and `Iterator::any`

This PR adds an assist for transforming expressions of the form `!iter.any(|x| predicate(x))` into `iter.all(|x| !predicate(x))` and vice versa.

[IteratorDeMorgans.webm](https://github.com/rust-lang/rust-analyzer/assets/52933714/aad1a299-6620-432b-9106-aafd2a7fa9f5)

Closes #15694
This commit is contained in:
bors 2023-10-04 11:08:44 +00:00
commit 7e9da40078
3 changed files with 357 additions and 2 deletions

View File

@ -1,7 +1,13 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use ide_db::{
assists::GroupLabel,
famous_defs::FamousDefs,
source_change::SourceChangeBuilder,
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
};
use syntax::{ use syntax::{
ast::{self, AstNode, Expr::BinExpr}, ast::{self, make, AstNode, Expr::BinExpr, HasArgList},
ted::{self, Position}, ted::{self, Position},
SyntaxKind, SyntaxKind,
}; };
@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
let dm_lhs = demorganed.lhs()?; let dm_lhs = demorganed.lhs()?;
acc.add( acc.add_group(
&GroupLabel("Apply De Morgan's law".to_string()),
AssistId("apply_demorgan", AssistKind::RefactorRewrite), AssistId("apply_demorgan", AssistKind::RefactorRewrite),
"Apply De Morgan's law", "Apply De Morgan's law",
op_range, op_range,
@ -143,6 +150,127 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
) )
} }
// Assist: apply_demorgan_iterator
//
// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to
// `Iterator::all` and `Iterator::any`.
//
// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
// `Iterator::all` into `Iterator::any`.
//
// ```
// # //- minicore: iterator
// fn main() {
// let arr = [1, 2, 3];
// if !arr.into_iter().$0any(|num| num == 4) {
// println!("foo");
// }
// }
// ```
// ->
// ```
// fn main() {
// let arr = [1, 2, 3];
// if arr.into_iter().all(|num| num != 4) {
// println!("foo");
// }
// }
// ```
pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
let closure_body = closure_expr.body()?;
let op_range = method_call.syntax().text_range();
let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
acc.add_group(
&GroupLabel("Apply De Morgan's law".to_string()),
AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite),
label,
op_range,
|edit| {
// replace the method name
let new_name = match name.text().as_str() {
"all" => make::name_ref("any"),
"any" => make::name_ref("all"),
_ => unreachable!(),
}
.clone_for_update();
edit.replace_ast(name, new_name);
// negate all tail expressions in the closure body
let tail_cb = &mut |e: &_| tail_cb_impl(edit, e);
walk_expr(&closure_body, &mut |expr| {
if let ast::Expr::ReturnExpr(ret_expr) = expr {
if let Some(ret_expr_arg) = &ret_expr.expr() {
for_each_tail_expr(ret_expr_arg, tail_cb);
}
}
});
for_each_tail_expr(&closure_body, tail_cb);
// negate the whole method call
if let Some(prefix_expr) = method_call
.syntax()
.parent()
.and_then(ast::PrefixExpr::cast)
.filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
{
edit.delete(
prefix_expr
.op_token()
.expect("prefix expression always has an operator")
.text_range(),
);
} else {
edit.insert(method_call.syntax().text_range().start(), "!");
}
},
)
}
/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
fn validate_method_call_expr(
ctx: &AssistContext<'_>,
method_call: &ast::MethodCallExpr,
) -> Option<(ast::NameRef, ast::Expr)> {
let name_ref = method_call.name_ref()?;
if name_ref.text() != "all" && name_ref.text() != "any" {
return None;
}
let arg_expr = method_call.arg_list()?.args().next()?;
let sema = &ctx.sema;
let receiver = method_call.receiver()?;
let it_type = sema.type_of_expr(&receiver)?.adjusted();
let module = sema.scope(receiver.syntax())?.module();
let krate = module.krate();
let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
}
fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) {
match e {
ast::Expr::BreakExpr(break_expr) => {
if let Some(break_expr_arg) = break_expr.expr() {
for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e))
}
}
ast::Expr::ReturnExpr(_) => {
// all return expressions have already been handled by the walk loop
}
e => {
let inverted_body = invert_boolean_expression(e.clone());
edit.replace(e.syntax().text_range(), inverted_body.syntax().text());
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -255,4 +383,206 @@ fn demorgan_removes_pars_in_eq_precedence() {
"fn() { let x = a && b && c; }", "fn() { let x = a && b && c; }",
) )
} }
#[test]
fn demorgan_iterator_any_all_reverse() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| num $0!= 4) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().any(|num| num == 4) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_all_any() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0all(|num| num > 3) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().any(|num| num <= 3) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_multiple_terms() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_double_negation() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0all(|num| !(num > 3)) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().any(|num| num > 3) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_double_parens() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_multiline() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if arr
.into_iter()
.all$0(|num| !num.is_negative())
{
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if !arr
.into_iter()
.any(|num| num.is_negative())
{
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_block_closure() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [-1, 1, 2, 3];
if arr.into_iter().all(|num: i32| {
$0if num.is_positive() {
num <= 3
} else {
num >= -1
}
}) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [-1, 1, 2, 3];
if !arr.into_iter().any(|num: i32| {
if num.is_positive() {
num > 3
} else {
num < -1
}
}) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_wrong_method() {
check_assist_not_applicable(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0map(|num| num > 3) {
println!("foo");
}
}
"#,
);
}
} }

View File

@ -226,6 +226,7 @@ pub(crate) fn all() -> &'static [Handler] {
add_return_type::add_return_type, add_return_type::add_return_type,
add_turbo_fish::add_turbo_fish, add_turbo_fish::add_turbo_fish,
apply_demorgan::apply_demorgan, apply_demorgan::apply_demorgan,
apply_demorgan::apply_demorgan_iterator,
auto_import::auto_import, auto_import::auto_import,
bind_unused_param::bind_unused_param, bind_unused_param::bind_unused_param,
bool_to_enum::bool_to_enum, bool_to_enum::bool_to_enum,

View File

@ -244,6 +244,30 @@ fn main() {
) )
} }
#[test]
fn doctest_apply_demorgan_iterator() {
check_doc_test(
"apply_demorgan_iterator",
r#####"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0any(|num| num == 4) {
println!("foo");
}
}
"#####,
r#####"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| num != 4) {
println!("foo");
}
}
"#####,
)
}
#[test] #[test]
fn doctest_auto_import() { fn doctest_auto_import() {
check_doc_test( check_doc_test(