diff --git a/clippy_lints/src/bool_assert_comparison.rs b/clippy_lints/src/bool_assert_comparison.rs index 8d3f68565b2..171d22f9986 100644 --- a/clippy_lints/src/bool_assert_comparison.rs +++ b/clippy_lints/src/bool_assert_comparison.rs @@ -1,9 +1,11 @@ -use clippy_utils::diagnostics::span_lint_and_sugg; -use clippy_utils::{ast_utils, is_direct_expn_of}; -use rustc_ast::ast::{Expr, ExprKind, Lit, LitKind}; +use clippy_utils::{diagnostics::span_lint_and_sugg, higher, is_direct_expn_of, ty::implements_trait}; +use rustc_ast::ast::LitKind; use rustc_errors::Applicability; -use rustc_lint::{EarlyContext, EarlyLintPass}; +use rustc_hir::*; +use rustc_lint::{LateContext, LateLintPass}; +use rustc_middle::ty; use rustc_session::{declare_lint_pass, declare_tool_lint}; +use rustc_span::symbol::Ident; declare_clippy_lint! { /// ### What it does @@ -28,45 +30,77 @@ declare_clippy_lint! { declare_lint_pass!(BoolAssertComparison => [BOOL_ASSERT_COMPARISON]); -fn is_bool_lit(e: &Expr) -> bool { +fn is_bool_lit(e: &Expr<'_>) -> bool { matches!( e.kind, ExprKind::Lit(Lit { - kind: LitKind::Bool(_), + node: LitKind::Bool(_), .. }) ) && !e.span.from_expansion() } -impl EarlyLintPass for BoolAssertComparison { - fn check_expr(&mut self, cx: &EarlyContext<'_>, e: &Expr) { +fn impl_not_trait_with_bool_out(cx: &LateContext<'tcx>, e: &'tcx Expr<'_>) -> bool { + let ty = cx.typeck_results().expr_ty(e); + + cx.tcx + .lang_items() + .not_trait() + .filter(|id| implements_trait(cx, ty, *id, &[])) + .and_then(|id| { + cx.tcx.associated_items(id).find_by_name_and_kind( + cx.tcx, + Ident::from_str("Output"), + ty::AssocKind::Type, + id, + ) + }) + .map_or(false, |item| { + let proj = cx.tcx.mk_projection(item.def_id, cx.tcx.mk_substs_trait(ty, &[])); + let nty = cx.tcx.normalize_erasing_regions(cx.param_env, proj); + + nty.is_bool() + }) +} + +impl<'tcx> LateLintPass<'tcx> for BoolAssertComparison { + fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) { let macros = ["assert_eq", "debug_assert_eq"]; let inverted_macros = ["assert_ne", "debug_assert_ne"]; for mac in macros.iter().chain(inverted_macros.iter()) { - if let Some(span) = is_direct_expn_of(e.span, mac) { - if let Some([a, b]) = ast_utils::extract_assert_macro_args(e) { - let nb_bool_args = is_bool_lit(a) as usize + is_bool_lit(b) as usize; + if let Some(span) = is_direct_expn_of(expr.span, mac) { + if let Some(args) = higher::extract_assert_macro_args(expr) { + if let [a, b, ..] = args[..] { + let nb_bool_args = is_bool_lit(a) as usize + is_bool_lit(b) as usize; - if nb_bool_args != 1 { - // If there are two boolean arguments, we definitely don't understand - // what's going on, so better leave things as is... - // - // Or there is simply no boolean and then we can leave things as is! + if nb_bool_args != 1 { + // If there are two boolean arguments, we definitely don't understand + // what's going on, so better leave things as is... + // + // Or there is simply no boolean and then we can leave things as is! + return; + } + + if !impl_not_trait_with_bool_out(cx, a) || !impl_not_trait_with_bool_out(cx, b) { + // At this point the expression which is not a boolean + // literal does not implement Not trait with a bool output, + // so we cannot suggest to rewrite our code + return; + } + + let non_eq_mac = &mac[..mac.len() - 3]; + span_lint_and_sugg( + cx, + BOOL_ASSERT_COMPARISON, + span, + &format!("used `{}!` with a literal bool", mac), + "replace it with", + format!("{}!(..)", non_eq_mac), + Applicability::MaybeIncorrect, + ); return; } - - let non_eq_mac = &mac[..mac.len() - 3]; - span_lint_and_sugg( - cx, - BOOL_ASSERT_COMPARISON, - span, - &format!("used `{}!` with a literal bool", mac), - "replace it with", - format!("{}!(..)", non_eq_mac), - Applicability::MaybeIncorrect, - ); - return; } } } diff --git a/clippy_lints/src/lib.rs b/clippy_lints/src/lib.rs index be97776187f..5987f9e5b0c 100644 --- a/clippy_lints/src/lib.rs +++ b/clippy_lints/src/lib.rs @@ -2115,7 +2115,7 @@ pub fn register_plugins(store: &mut rustc_lint::LintStore, sess: &Session, conf: store.register_late_pass(|| box from_str_radix_10::FromStrRadix10); store.register_late_pass(|| box manual_map::ManualMap); store.register_late_pass(move || box if_then_some_else_none::IfThenSomeElseNone::new(msrv)); - store.register_early_pass(|| box bool_assert_comparison::BoolAssertComparison); + store.register_late_pass(|| box bool_assert_comparison::BoolAssertComparison); store.register_early_pass(move || box module_style::ModStyle); store.register_late_pass(|| box unused_async::UnusedAsync); let disallowed_types = conf.disallowed_types.iter().cloned().collect::>();