From 34d3490198fe6e7f56eb60c9665d25ef9cfd6f4e Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sun, 1 Oct 2023 21:27:06 -0700 Subject: [PATCH 1/2] feat: add assist for applying De Morgan's law to iterators --- .../src/handlers/apply_demorgan.rs | 329 +++++++++++++++++- crates/ide-assists/src/lib.rs | 1 + crates/ide-assists/src/tests/generated.rs | 24 ++ 3 files changed, 352 insertions(+), 2 deletions(-) diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs index 66bc2f6dadc..74db300465a 100644 --- a/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -1,7 +1,13 @@ use std::collections::VecDeque; +use ide_db::{ + assists::GroupLabel, + famous_defs::FamousDefs, + source_change::SourceChangeBuilder, + syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, +}; use syntax::{ - ast::{self, AstNode, Expr::BinExpr}, + ast::{self, make, AstNode, Expr::BinExpr, HasArgList}, ted::{self, Position}, SyntaxKind, }; @@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti let dm_lhs = demorganed.lhs()?; - acc.add( + acc.add_group( + &GroupLabel("Apply De Morgan's law".to_string()), AssistId("apply_demorgan", AssistKind::RefactorRewrite), "Apply De Morgan's law", op_range, @@ -143,6 +150,122 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti ) } +// Assist: apply_demorgan_iterator +// +// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to +// `Iterator::all` and `Iterator::any`. +// +// This transforms expressions of the form `!iter.any(|x| predicate(x))` into +// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for +// `Iterator::all` into `Iterator::any`. +// +// ``` +// # //- minicore: iterator +// fn main() { +// let arr = [1, 2, 3]; +// if !arr.into_iter().$0any(|num| num == 4) { +// println!("foo"); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// let arr = [1, 2, 3]; +// if arr.into_iter().all(|num| num != 4) { +// println!("foo"); +// } +// } +// ``` +pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?; + let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?; + + let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None }; + let closure_body = closure_expr.body()?; + + let op_range = method_call.syntax().text_range(); + let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str()); + acc.add_group( + &GroupLabel("Apply De Morgan's law".to_string()), + AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite), + label, + op_range, + |edit| { + // replace the method name + let new_name = match name.text().as_str() { + "all" => make::name_ref("any"), + "any" => make::name_ref("all"), + _ => unreachable!(), + } + .clone_for_update(); + edit.replace_ast(name, new_name); + + // negate all tail expressions in the closure body + let tail_cb = &mut |e: &_| tail_cb_impl(edit, e); + walk_expr(&closure_body, &mut |expr| { + if let ast::Expr::ReturnExpr(ret_expr) = expr { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, tail_cb); + } + } + }); + for_each_tail_expr(&closure_body, tail_cb); + + // negate the whole method call + if let Some(prefix_expr) = method_call + .syntax() + .parent() + .and_then(ast::PrefixExpr::cast) + .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not))) + { + edit.delete(prefix_expr.op_token().unwrap().text_range()); + } else { + edit.insert(method_call.syntax().text_range().start(), "!"); + } + }, + ) +} + +/// Ensures that the method call is to `Iterator::all` or `Iterator::any`. +fn validate_method_call_expr( + ctx: &AssistContext<'_>, + method_call: &ast::MethodCallExpr, +) -> Option<(ast::NameRef, ast::Expr)> { + let name_ref = method_call.name_ref()?; + if name_ref.text() != "all" && name_ref.text() != "any" { + return None; + } + let arg_expr = method_call.arg_list()?.args().next()?; + + let sema = &ctx.sema; + + let receiver = method_call.receiver()?; + let it_type = sema.type_of_expr(&receiver)?.adjusted(); + let module = sema.scope(receiver.syntax())?.module(); + let krate = module.krate(); + + let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?; + it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr)) +} + +fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) { + match e { + ast::Expr::BreakExpr(break_expr) => { + if let Some(break_expr_arg) = break_expr.expr() { + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e)) + } + } + ast::Expr::ReturnExpr(_) => { + // all return expressions have already been handled by the walk loop + } + e => { + let inverted_body = invert_boolean_expression(e.clone()); + edit.replace(e.syntax().text_range(), inverted_body.syntax().text()); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -255,4 +378,206 @@ fn demorgan_removes_pars_in_eq_precedence() { "fn() { let x = a && b && c; }", ) } + + #[test] + fn demorgan_iterator_any_all_reverse() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().all(|num| num $0!= 4) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().any(|num| num == 4) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_all_any() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0all(|num| num > 3) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().any(|num| num <= 3) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_multiple_terms() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_double_negation() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0all(|num| !(num > 3)) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().any(|num| num > 3) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_double_parens() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_multiline() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if arr + .into_iter() + .all$0(|num| !num.is_negative()) + { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [1, 2, 3]; + if !arr + .into_iter() + .any(|num| num.is_negative()) + { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_block_closure() { + check_assist( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [-1, 1, 2, 3]; + if arr.into_iter().all(|num: i32| { + $0if num.is_positive() { + num <= 3 + } else { + num >= -1 + } + }) { + println!("foo"); + } +} +"#, + r#" +fn main() { + let arr = [-1, 1, 2, 3]; + if !arr.into_iter().any(|num: i32| { + if num.is_positive() { + num > 3 + } else { + num < -1 + } + }) { + println!("foo"); + } +} +"#, + ); + } + + #[test] + fn demorgan_iterator_wrong_method() { + check_assist_not_applicable( + apply_demorgan_iterator, + r#" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0map(|num| num > 3) { + println!("foo"); + } +} +"#, + ); + } } diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index a17ce93e928..50476ccf363 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -226,6 +226,7 @@ pub(crate) fn all() -> &'static [Handler] { add_return_type::add_return_type, add_turbo_fish::add_turbo_fish, apply_demorgan::apply_demorgan, + apply_demorgan::apply_demorgan_iterator, auto_import::auto_import, bind_unused_param::bind_unused_param, bool_to_enum::bool_to_enum, diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 5a815d5c6a1..65bd74c018b 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -244,6 +244,30 @@ fn main() { ) } +#[test] +fn doctest_apply_demorgan_iterator() { + check_doc_test( + "apply_demorgan_iterator", + r#####" +//- minicore: iterator +fn main() { + let arr = [1, 2, 3]; + if !arr.into_iter().$0any(|num| num == 4) { + println!("foo"); + } +} +"#####, + r#####" +fn main() { + let arr = [1, 2, 3]; + if arr.into_iter().all(|num| num != 4) { + println!("foo"); + } +} +"#####, + ) +} + #[test] fn doctest_auto_import() { check_doc_test( From c266387e130de11c60be2c2d7d0a1d5c3bc3eb62 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Wed, 4 Oct 2023 13:06:23 +0200 Subject: [PATCH 2/2] Replace unwrap with expect --- crates/ide-assists/src/handlers/apply_demorgan.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs index 74db300465a..2d41243c20e 100644 --- a/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -219,7 +219,12 @@ pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> .and_then(ast::PrefixExpr::cast) .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not))) { - edit.delete(prefix_expr.op_token().unwrap().text_range()); + edit.delete( + prefix_expr + .op_token() + .expect("prefix expression always has an operator") + .text_range(), + ); } else { edit.insert(method_call.syntax().text_range().start(), "!"); }