diff --git a/clippy_lints/src/copies.rs b/clippy_lints/src/copies.rs index 1deff9684a1..0e3d9317590 100644 --- a/clippy_lints/src/copies.rs +++ b/clippy_lints/src/copies.rs @@ -1,13 +1,16 @@ use clippy_utils::diagnostics::{span_lint_and_note, span_lint_and_then}; use clippy_utils::source::{first_line_of_span, indent_of, reindent_multiline, snippet, snippet_opt}; +use clippy_utils::ty::needs_ordered_drop; +use clippy_utils::visitors::for_each_expr; use clippy_utils::{ - eq_expr_value, get_enclosing_block, hash_expr, hash_stmt, if_sequence, is_else_clause, is_lint_allowed, - search_same, ContainsName, HirEqInterExpr, SpanlessEq, + capture_local_usage, eq_expr_value, get_enclosing_block, hash_expr, hash_stmt, if_sequence, is_else_clause, + is_lint_allowed, path_to_local, search_same, ContainsName, HirEqInterExpr, SpanlessEq, }; use core::iter; +use core::ops::ControlFlow; use rustc_errors::Applicability; use rustc_hir::intravisit; -use rustc_hir::{BinOpKind, Block, Expr, ExprKind, HirId, Stmt, StmtKind}; +use rustc_hir::{BinOpKind, Block, Expr, ExprKind, HirId, HirIdSet, Stmt, StmtKind}; use rustc_lint::{LateContext, LateLintPass}; use rustc_session::{declare_lint_pass, declare_tool_lint}; use rustc_span::hygiene::walk_chain; @@ -214,7 +217,7 @@ fn lint_if_same_then_else(cx: &LateContext<'_>, conds: &[&Expr<'_>], blocks: &[& fn lint_branches_sharing_code<'tcx>( cx: &LateContext<'tcx>, conds: &[&'tcx Expr<'_>], - blocks: &[&Block<'tcx>], + blocks: &[&'tcx Block<'_>], expr: &'tcx Expr<'_>, ) { // We only lint ifs with multiple blocks @@ -340,6 +343,21 @@ fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool { } } +/// Checks if the statement modifies or moves any of the given locals. +fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &HirIdSet) -> bool { + for_each_expr(s, |e| { + if let Some(id) = path_to_local(e) + && locals.contains(&id) + && !capture_local_usage(cx, e).is_imm_ref() + { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + }) + .is_some() +} + /// Checks if the given statement should be considered equal to the statement in the same position /// for each block. fn eq_stmts( @@ -365,18 +383,52 @@ fn eq_stmts( .all(|b| get_stmt(b).map_or(false, |s| eq.eq_stmt(s, stmt))) } -fn scan_block_for_eq(cx: &LateContext<'_>, _conds: &[&Expr<'_>], block: &Block<'_>, blocks: &[&Block<'_>]) -> BlockEq { +#[expect(clippy::too_many_lines)] +fn scan_block_for_eq<'tcx>( + cx: &LateContext<'tcx>, + conds: &[&'tcx Expr<'_>], + block: &'tcx Block<'_>, + blocks: &[&'tcx Block<'_>], +) -> BlockEq { let mut eq = SpanlessEq::new(cx); let mut eq = eq.inter_expr(); let mut moved_locals = Vec::new(); + let mut cond_locals = HirIdSet::default(); + for &cond in conds { + let _: Option = for_each_expr(cond, |e| { + if let Some(id) = path_to_local(e) { + cond_locals.insert(id); + } + ControlFlow::Continue(()) + }); + } + + let mut local_needs_ordered_drop = false; let start_end_eq = block .stmts .iter() .enumerate() - .find(|&(i, stmt)| !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)) + .find(|&(i, stmt)| { + if let StmtKind::Local(l) = stmt.kind + && needs_ordered_drop(cx, cx.typeck_results().node_type(l.hir_id)) + { + local_needs_ordered_drop = true; + return true; + } + modifies_any_local(cx, stmt, &cond_locals) + || !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals) + }) .map_or(block.stmts.len(), |(i, _)| i); + if local_needs_ordered_drop { + return BlockEq { + start_end_eq, + end_begin_eq: None, + moved_locals, + }; + } + // Walk backwards through the final expression/statements so long as their hashes are equal. Note // `SpanlessHash` treats all local references as equal allowing locals declared earlier in the block // to match those in other blocks. e.g. If each block ends with the following the hash value will be diff --git a/clippy_utils/src/lib.rs b/clippy_utils/src/lib.rs index 0e739303683..fd0c6869929 100644 --- a/clippy_utils/src/lib.rs +++ b/clippy_utils/src/lib.rs @@ -890,7 +890,7 @@ pub fn capture_local_usage<'tcx>(cx: &LateContext<'tcx>, e: &Expr<'_>) -> Captur Node::Expr(e) => match e.kind { ExprKind::AddrOf(_, mutability, _) => return CaptureKind::Ref(mutability), ExprKind::Index(..) | ExprKind::Unary(UnOp::Deref, _) => capture = CaptureKind::Ref(Mutability::Not), - ExprKind::Assign(lhs, ..) | ExprKind::Assign(_, lhs, _) if lhs.hir_id == child_id => { + ExprKind::Assign(lhs, ..) | ExprKind::AssignOp(_, lhs, _) if lhs.hir_id == child_id => { return CaptureKind::Ref(Mutability::Mut); }, ExprKind::Field(..) => { diff --git a/clippy_utils/src/visitors.rs b/clippy_utils/src/visitors.rs index e89a46b8538..0b9f75238a6 100644 --- a/clippy_utils/src/visitors.rs +++ b/clippy_utils/src/visitors.rs @@ -5,7 +5,7 @@ use rustc_hir as hir; use rustc_hir::def::{CtorKind, DefKind, Res}; use rustc_hir::intravisit::{self, walk_block, walk_expr, Visitor}; use rustc_hir::{ - Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Let, QPath, Stmt, UnOp, + Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Let, Pat, QPath, Stmt, UnOp, UnsafeSource, Unsafety, }; use rustc_lint::LateContext; @@ -13,6 +13,74 @@ use rustc_middle::hir::map::Map; use rustc_middle::hir::nested_filter; use rustc_middle::ty::adjustment::Adjust; use rustc_middle::ty::{self, Ty, TypeckResults}; +use rustc_span::Span; + +mod internal { + /// Trait for visitor functions to control whether or not to descend to child nodes. Implemented + /// for only two types. `()` always descends. `Descend` allows controlled descent. + pub trait Continue { + fn descend(&self) -> bool; + } +} +use internal::Continue; + +impl Continue for () { + fn descend(&self) -> bool { + true + } +} + +/// Allows for controlled descent whe using visitor functions. Use `()` instead when always +/// descending into child nodes. +#[derive(Clone, Copy)] +pub enum Descend { + Yes, + No, +} +impl From for Descend { + fn from(from: bool) -> Self { + if from { Self::Yes } else { Self::No } + } +} +impl Continue for Descend { + fn descend(&self) -> bool { + matches!(self, Self::Yes) + } +} + +/// Calls the given function once for each expression contained. This does not enter any bodies or +/// nested items. +pub fn for_each_expr<'tcx, B, C: Continue>( + node: impl Visitable<'tcx>, + f: impl FnMut(&'tcx Expr<'tcx>) -> ControlFlow, +) -> Option { + struct V { + f: F, + res: Option, + } + impl<'tcx, B, C: Continue, F: FnMut(&'tcx Expr<'tcx>) -> ControlFlow> Visitor<'tcx> for V { + fn visit_expr(&mut self, e: &'tcx Expr<'tcx>) { + if self.res.is_some() { + return; + } + match (self.f)(e) { + ControlFlow::Continue(c) if c.descend() => walk_expr(self, e), + ControlFlow::Break(b) => self.res = Some(b), + ControlFlow::Continue(_) => (), + } + } + + // Avoid unnecessary `walk_*` calls. + fn visit_ty(&mut self, _: &'tcx hir::Ty<'tcx>) {} + fn visit_pat(&mut self, _: &'tcx Pat<'tcx>) {} + fn visit_qpath(&mut self, _: &'tcx QPath<'tcx>, _: HirId, _: Span) {} + // Avoid monomorphising all `visit_*` functions. + fn visit_nested_item(&mut self, _: ItemId) {} + } + let mut v = V { f, res: None }; + node.visit(&mut v); + v.res +} /// Convenience method for creating a `Visitor` with just `visit_expr` overridden and nested /// bodies (i.e. closures) are visited. diff --git a/tests/ui/branches_sharing_code/false_positives.rs b/tests/ui/branches_sharing_code/false_positives.rs index 06448200951..5e3a1a29693 100644 --- a/tests/ui/branches_sharing_code/false_positives.rs +++ b/tests/ui/branches_sharing_code/false_positives.rs @@ -1,6 +1,8 @@ #![allow(dead_code)] #![deny(clippy::if_same_then_else, clippy::branches_sharing_code)] +use std::sync::Mutex; + // ################################## // # Issue clippy#7369 // ################################## @@ -38,4 +40,56 @@ fn main() { let (y, x) = x; foo(x, y) }; + + let m = Mutex::new(0u32); + let l = m.lock().unwrap(); + let _ = if true { + drop(l); + println!("foo"); + m.lock().unwrap(); + 0 + } else if *l == 0 { + drop(l); + println!("foo"); + println!("bar"); + m.lock().unwrap(); + 1 + } else { + drop(l); + println!("foo"); + println!("baz"); + m.lock().unwrap(); + 2 + }; + + if true { + let _guard = m.lock(); + println!("foo"); + } else { + println!("foo"); + } + + if true { + let _guard = m.lock(); + println!("foo"); + println!("bar"); + } else { + let _guard = m.lock(); + println!("foo"); + println!("baz"); + } + + let mut c = 0; + for _ in 0..5 { + if c == 0 { + c += 1; + println!("0"); + } else if c == 1 { + c += 1; + println!("1"); + } else { + c += 1; + println!("more"); + } + } }