From beca92b245953873d273f19a08c7a927e5a3ed78 Mon Sep 17 00:00:00 2001
From: Aleksey Kladov <aleksey.kladov@gmail.com>
Date: Sat, 14 Aug 2021 16:40:00 +0300
Subject: [PATCH] internal: make invert binary op more robust

Previously, we only inverted comparison operators (< and the like) if
the type implemented Ord. This doesn't make sense: if `<` works, then
`>=` will work as well!

Extra semantic checks greatly reduce robustness and predictability of
the assist, it's better to keep things simple.
---
 .../src/handlers/apply_demorgan.rs            | 73 +++----------------
 .../src/handlers/convert_bool_then.rs         |  2 +-
 .../ide_assists/src/handlers/early_return.rs  |  2 +-
 crates/ide_assists/src/handlers/invert_if.rs  |  2 +-
 crates/ide_assists/src/tests/generated.rs     |  2 +-
 crates/ide_assists/src/utils.rs               | 50 +++----------
 6 files changed, 22 insertions(+), 109 deletions(-)

diff --git a/crates/ide_assists/src/handlers/apply_demorgan.rs b/crates/ide_assists/src/handlers/apply_demorgan.rs
index e2bd6e4567e..cafc4297fde 100644
--- a/crates/ide_assists/src/handlers/apply_demorgan.rs
+++ b/crates/ide_assists/src/handlers/apply_demorgan.rs
@@ -19,7 +19,7 @@ use crate::{utils::invert_boolean_expression, AssistContext, AssistId, AssistKin
 // ->
 // ```
 // fn main() {
-//     if !(x == 4 && !(y < 3.14)) {}
+//     if !(x == 4 && y >= 3.14) {}
 // }
 // ```
 pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
@@ -99,7 +99,7 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext) -> Option<(
             if let Some(paren_expr) = paren_expr {
                 for term in terms {
                     let range = term.syntax().text_range();
-                    let not_term = invert_boolean_expression(&ctx.sema, term);
+                    let not_term = invert_boolean_expression(term);
 
                     edit.replace(range, not_term.syntax().text());
                 }
@@ -114,21 +114,21 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext) -> Option<(
             } else {
                 if let Some(lhs) = terms.pop_front() {
                     let lhs_range = lhs.syntax().text_range();
-                    let not_lhs = invert_boolean_expression(&ctx.sema, lhs);
+                    let not_lhs = invert_boolean_expression(lhs);
 
                     edit.replace(lhs_range, format!("!({}", not_lhs.syntax().text()));
                 }
 
                 if let Some(rhs) = terms.pop_back() {
                     let rhs_range = rhs.syntax().text_range();
-                    let not_rhs = invert_boolean_expression(&ctx.sema, rhs);
+                    let not_rhs = invert_boolean_expression(rhs);
 
                     edit.replace(rhs_range, format!("{})", not_rhs.syntax().text()));
                 }
 
                 for term in terms {
                     let term_range = term.syntax().text_range();
-                    let not_term = invert_boolean_expression(&ctx.sema, term);
+                    let not_term = invert_boolean_expression(term);
                     edit.replace(term_range, not_term.syntax().text());
                 }
             }
@@ -156,40 +156,12 @@ mod tests {
         check_assist(
             apply_demorgan,
             r#"
-//- minicore: ord, derive
-#[derive(PartialEq, Eq, PartialOrd, Ord)]
 struct S;
-
-fn f() {
-    S < S &&$0 S <= S
-}
-"#,
-            r#"
-#[derive(PartialEq, Eq, PartialOrd, Ord)]
-struct S;
-
-fn f() {
-    !(S >= S || S > S)
-}
-"#,
-        );
-
-        check_assist(
-            apply_demorgan,
-            r#"
-//- minicore: ord, derive
-struct S;
-
-fn f() {
-    S < S &&$0 S <= S
-}
+fn f() { S < S &&$0 S <= S }
 "#,
             r#"
 struct S;
-
-fn f() {
-    !(!(S < S) || !(S <= S))
-}
+fn f() { !(S >= S || S > S) }
 "#,
         );
     }
@@ -199,39 +171,12 @@ fn f() {
         check_assist(
             apply_demorgan,
             r#"
-//- minicore: ord, derive
-#[derive(PartialEq, Eq, PartialOrd, Ord)]
 struct S;
-
-fn f() {
-    S > S &&$0 S >= S
-}
-"#,
-            r#"
-#[derive(PartialEq, Eq, PartialOrd, Ord)]
-struct S;
-
-fn f() {
-    !(S <= S || S < S)
-}
-"#,
-        );
-        check_assist(
-            apply_demorgan,
-            r#"
-//- minicore: ord, derive
-struct S;
-
-fn f() {
-    S > S &&$0 S >= S
-}
+fn f() { S > S &&$0 S >= S }
 "#,
             r#"
 struct S;
-
-fn f() {
-    !(!(S > S) || !(S >= S))
-}
+fn f() { !(S <= S || S < S) }
 "#,
         );
     }
diff --git a/crates/ide_assists/src/handlers/convert_bool_then.rs b/crates/ide_assists/src/handlers/convert_bool_then.rs
index 3bb78fe0f29..5adb3f5a1b3 100644
--- a/crates/ide_assists/src/handlers/convert_bool_then.rs
+++ b/crates/ide_assists/src/handlers/convert_bool_then.rs
@@ -97,7 +97,7 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) ->
                 e => e,
             };
 
-            let cond = if invert_cond { invert_boolean_expression(&ctx.sema, cond) } else { cond };
+            let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
             builder.replace(target, mcall.to_string());
diff --git a/crates/ide_assists/src/handlers/early_return.rs b/crates/ide_assists/src/handlers/early_return.rs
index b4745b84242..1b3fa898bb7 100644
--- a/crates/ide_assists/src/handlers/early_return.rs
+++ b/crates/ide_assists/src/handlers/early_return.rs
@@ -115,7 +115,7 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
                     let new_expr = {
                         let then_branch =
                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
-                        let cond = invert_boolean_expression(&ctx.sema, cond_expr);
+                        let cond = invert_boolean_expression(cond_expr);
                         make::expr_if(make::condition(cond, None), then_branch, None)
                             .indent(if_indent_level)
                     };
diff --git a/crates/ide_assists/src/handlers/invert_if.rs b/crates/ide_assists/src/handlers/invert_if.rs
index f7f38dffbda..50845cd9e03 100644
--- a/crates/ide_assists/src/handlers/invert_if.rs
+++ b/crates/ide_assists/src/handlers/invert_if.rs
@@ -47,7 +47,7 @@ pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     };
 
     acc.add(AssistId("invert_if", AssistKind::RefactorRewrite), "Invert if", if_range, |edit| {
-        let flip_cond = invert_boolean_expression(&ctx.sema, cond.clone());
+        let flip_cond = invert_boolean_expression(cond.clone());
         edit.replace_ast(cond, flip_cond);
 
         let else_node = else_block.syntax();
diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs
index c4df6aec9fc..853c41f78f4 100644
--- a/crates/ide_assists/src/tests/generated.rs
+++ b/crates/ide_assists/src/tests/generated.rs
@@ -151,7 +151,7 @@ fn main() {
 "#####,
         r#####"
 fn main() {
-    if !(x == 4 && !(y < 3.14)) {}
+    if !(x == 4 && y >= 3.14) {}
 }
 "#####,
     )
diff --git a/crates/ide_assists/src/utils.rs b/crates/ide_assists/src/utils.rs
index a4e4a00f78d..256ddb8c9b2 100644
--- a/crates/ide_assists/src/utils.rs
+++ b/crates/ide_assists/src/utils.rs
@@ -5,12 +5,8 @@ mod gen_trait_fn_body;
 
 use std::ops;
 
-use hir::{Adt, HasSource, Semantics};
-use ide_db::{
-    helpers::{FamousDefs, SnippetCap},
-    path_transform::PathTransform,
-    RootDatabase,
-};
+use hir::{Adt, HasSource};
+use ide_db::{helpers::SnippetCap, path_transform::PathTransform, RootDatabase};
 use itertools::Itertools;
 use stdx::format_to;
 use syntax::{
@@ -207,31 +203,19 @@ pub(crate) fn vis_offset(node: &SyntaxNode) -> TextSize {
         .unwrap_or_else(|| node.text_range().start())
 }
 
-pub(crate) fn invert_boolean_expression(
-    sema: &Semantics<RootDatabase>,
-    expr: ast::Expr,
-) -> ast::Expr {
-    invert_special_case(sema, &expr).unwrap_or_else(|| make::expr_prefix(T![!], expr))
+pub(crate) fn invert_boolean_expression(expr: ast::Expr) -> ast::Expr {
+    invert_special_case(&expr).unwrap_or_else(|| make::expr_prefix(T![!], expr))
 }
 
-fn invert_special_case(sema: &Semantics<RootDatabase>, expr: &ast::Expr) -> Option<ast::Expr> {
+fn invert_special_case(expr: &ast::Expr) -> Option<ast::Expr> {
     match expr {
         ast::Expr::BinExpr(bin) => match bin.op_kind()? {
             ast::BinOp::NegatedEqualityTest => bin.replace_op(T![==]).map(|it| it.into()),
             ast::BinOp::EqualityTest => bin.replace_op(T![!=]).map(|it| it.into()),
-            // Swap `<` with `>=`, `<=` with `>`, ... if operands `impl Ord`
-            ast::BinOp::LesserTest if bin_impls_ord(sema, bin) => {
-                bin.replace_op(T![>=]).map(|it| it.into())
-            }
-            ast::BinOp::LesserEqualTest if bin_impls_ord(sema, bin) => {
-                bin.replace_op(T![>]).map(|it| it.into())
-            }
-            ast::BinOp::GreaterTest if bin_impls_ord(sema, bin) => {
-                bin.replace_op(T![<=]).map(|it| it.into())
-            }
-            ast::BinOp::GreaterEqualTest if bin_impls_ord(sema, bin) => {
-                bin.replace_op(T![<]).map(|it| it.into())
-            }
+            ast::BinOp::LesserTest => bin.replace_op(T![>=]).map(|it| it.into()),
+            ast::BinOp::LesserEqualTest => bin.replace_op(T![>]).map(|it| it.into()),
+            ast::BinOp::GreaterTest => bin.replace_op(T![<=]).map(|it| it.into()),
+            ast::BinOp::GreaterEqualTest => bin.replace_op(T![<]).map(|it| it.into()),
             // Parenthesize other expressions before prefixing `!`
             _ => Some(make::expr_prefix(T![!], make::expr_paren(expr.clone()))),
         },
@@ -267,22 +251,6 @@ fn invert_special_case(sema: &Semantics<RootDatabase>, expr: &ast::Expr) -> Opti
     }
 }
 
-fn bin_impls_ord(sema: &Semantics<RootDatabase>, bin: &ast::BinExpr) -> bool {
-    match (
-        bin.lhs().and_then(|lhs| sema.type_of_expr(&lhs)).map(hir::TypeInfo::adjusted),
-        bin.rhs().and_then(|rhs| sema.type_of_expr(&rhs)).map(hir::TypeInfo::adjusted),
-    ) {
-        (Some(lhs_ty), Some(rhs_ty)) if lhs_ty == rhs_ty => {
-            let krate = sema.scope(bin.syntax()).module().map(|it| it.krate());
-            let ord_trait = FamousDefs(sema, krate).core_cmp_Ord();
-            ord_trait.map_or(false, |ord_trait| {
-                lhs_ty.autoderef(sema.db).any(|ty| ty.impls_trait(sema.db, ord_trait, &[]))
-            })
-        }
-        _ => false,
-    }
-}
-
 pub(crate) fn next_prev() -> impl Iterator<Item = Direction> {
     [Direction::Next, Direction::Prev].iter().copied()
 }