use super::utils::make_iterator_snippet; use super::MANUAL_FIND; use clippy_utils::diagnostics::span_lint_and_then; use clippy_utils::source::snippet_with_applicability; use clippy_utils::ty::implements_trait; use clippy_utils::{higher, is_res_lang_ctor, path_res, peel_blocks_with_stmt}; use if_chain::if_chain; use rustc_errors::Applicability; use rustc_hir::def::Res; use rustc_hir::lang_items::LangItem; use rustc_hir::{BindingAnnotation, Block, Expr, ExprKind, HirId, Node, Pat, PatKind, Stmt, StmtKind}; use rustc_lint::LateContext; use rustc_span::source_map::Span; pub(super) fn check<'tcx>( cx: &LateContext<'tcx>, pat: &'tcx Pat<'_>, arg: &'tcx Expr<'_>, body: &'tcx Expr<'_>, span: Span, expr: &'tcx Expr<'_>, ) { let inner_expr = peel_blocks_with_stmt(body); // Check for the specific case that the result is returned and optimize suggestion for that (more // cases can be added later) if let Some(higher::If { cond, then, r#else: None, }) = higher::If::hir(inner_expr) && let Some(binding_id) = get_binding(pat) && let ExprKind::Block(block, _) = then.kind && let [stmt] = block.stmts && let StmtKind::Semi(semi) = stmt.kind && let ExprKind::Ret(Some(ret_value)) = semi.kind && let ExprKind::Call(ctor, [inner_ret]) = ret_value.kind && is_res_lang_ctor(cx, path_res(cx, ctor), LangItem::OptionSome) && path_res(cx, inner_ret) == Res::Local(binding_id) && let Some((last_stmt, last_ret)) = last_stmt_and_ret(cx, expr) { let mut applicability = Applicability::MachineApplicable; let mut snippet = make_iterator_snippet(cx, arg, &mut applicability); // Checks if `pat` is a single reference to a binding (`&x`) let is_ref_to_binding = matches!(pat.kind, PatKind::Ref(inner, _) if matches!(inner.kind, PatKind::Binding(..))); // If `pat` is not a binding or a reference to a binding (`x` or `&x`) // we need to map it to the binding returned by the function (i.e. `.map(|(x, _)| x)`) if !(matches!(pat.kind, PatKind::Binding(..)) || is_ref_to_binding) { snippet.push_str( &format!( ".map(|{}| {})", snippet_with_applicability(cx, pat.span, "..", &mut applicability), snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability), )[..], ); } let ty = cx.typeck_results().expr_ty(inner_ret); if cx.tcx.lang_items().copy_trait().map_or(false, |id| implements_trait(cx, ty, id, &[])) { snippet.push_str( &format!( ".find(|{}{}| {})", "&".repeat(1 + usize::from(is_ref_to_binding)), snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability), snippet_with_applicability(cx, cond.span, "..", &mut applicability), )[..], ); if is_ref_to_binding { snippet.push_str(".copied()"); } } else { applicability = Applicability::MaybeIncorrect; snippet.push_str( &format!( ".find(|{}| {})", snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability), snippet_with_applicability(cx, cond.span, "..", &mut applicability), )[..], ); } // Extends to `last_stmt` to include semicolon in case of `return None;` let lint_span = span.to(last_stmt.span).to(last_ret.span); span_lint_and_then( cx, MANUAL_FIND, lint_span, "manual implementation of `Iterator::find`", |diag| { if applicability == Applicability::MaybeIncorrect { diag.note("you may need to dereference some variables"); } diag.span_suggestion( lint_span, "replace with an iterator", snippet, applicability, ); }, ); } } fn get_binding(pat: &Pat<'_>) -> Option { let mut hir_id = None; let mut count = 0; pat.each_binding(|annotation, id, _, _| { count += 1; if count > 1 { hir_id = None; return; } if let BindingAnnotation::NONE = annotation { hir_id = Some(id); } }); hir_id } // Returns the last statement and last return if function fits format for lint fn last_stmt_and_ret<'tcx>( cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>, ) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> { // Returns last non-return statement and the last return fn extract<'tcx>(block: &Block<'tcx>) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> { if let [.., last_stmt] = block.stmts { if let Some(ret) = block.expr { return Some((last_stmt, ret)); } if let [.., snd_last, _] = block.stmts && let StmtKind::Semi(last_expr) = last_stmt.kind && let ExprKind::Ret(Some(ret)) = last_expr.kind { return Some((snd_last, ret)); } } None } let mut parent_iter = cx.tcx.hir().parent_iter(expr.hir_id); if let Some((node_hir, Node::Stmt(..))) = parent_iter.next() // This should be the loop // This should be the function body && let Some((_, Node::Block(block))) = parent_iter.next() && let Some((last_stmt, last_ret)) = extract(block) && last_stmt.hir_id == node_hir && is_res_lang_ctor(cx, path_res(cx, last_ret), LangItem::OptionNone) && let Some((_, Node::Expr(_block))) = parent_iter.next() // This includes the function header && let Some((_, func)) = parent_iter.next() && func.fn_kind().is_some() { Some((block.stmts.last().unwrap(), last_ret)) } else { None } }