diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index ff70618ca12..718409e1599 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -12,6 +12,7 @@ use itertools::Itertools; use rustc_hash::FxHashSet; use rustc_pattern_analysis::usefulness::{compute_match_usefulness, ValidityConstraint}; +use syntax::{ast, AstNode}; use tracing::debug; use triomphe::Arc; use typed_arena::Arena; @@ -108,7 +109,7 @@ fn validate_body(&mut self, db: &dyn HirDatabase) { self.check_for_trailing_return(*body_expr, &body); } Expr::If { .. } => { - self.check_for_unnecessary_else(id, expr, &body); + self.check_for_unnecessary_else(id, expr, db); } Expr::Block { .. } => { self.validate_block(db, expr); @@ -336,32 +337,35 @@ fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) { } } - fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) { + fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) { if let Expr::If { condition: _, then_branch, else_branch } = expr { - if let Some(else_branch) = else_branch { - // If else branch has a tail, it is an "expression" that produces a value, - // e.g. `let a = if { ... } else { ... };` and this `else` is not unnecessary - let mut branch = *else_branch; - loop { - match body.exprs[branch] { - Expr::Block { tail: Some(_), .. } => return, - Expr::If { then_branch, else_branch, .. } => { - if let Expr::Block { tail: Some(_), .. } = body.exprs[then_branch] { - return; - } - if let Some(else_branch) = else_branch { - // Continue checking for branches like `if { ... } else if { ... } else...` - branch = else_branch; - continue; - } - } - _ => break, - } - break; - } - } else { + if else_branch.is_none() { return; } + let (body, source_map) = db.body_with_source_map(self.owner); + let Ok(source_ptr) = source_map.expr_syntax(id) else { + return; + }; + let root = source_ptr.file_syntax(db.upcast()); + let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else { + return; + }; + let mut top_if_expr = if_expr; + loop { + let parent = top_if_expr.syntax().parent(); + let has_parent_let_stmt = + parent.as_ref().map_or(false, |node| ast::LetStmt::can_cast(node.kind())); + if has_parent_let_stmt { + // Bail if parent or direct ancestor is a let stmt. + return; + } + let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else { + // Parent is neither an if expr nor a let stmt. + break; + }; + // Check parent if expr. + top_if_expr = parent_if_expr; + } if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] { let last_then_expr = tail.or_else(|| match statements.last()? { Statement::Expr { expr, .. } => Some(*expr), diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs index bbc10e96cef..351f728747e 100644 --- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs +++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs @@ -400,15 +400,7 @@ fn test1(a: bool) { }; } -fn test2(a: bool) -> i32 { - if a { - return 1; - } else { - 0 - } -} - -fn test3(a: bool, b: bool, c: bool) { +fn test2(a: bool, b: bool, c: bool) { let _x = if a { return; } else if b {