From ff7031008651021c330b93d4bd502810022b045d Mon Sep 17 00:00:00 2001 From: davidsemakula Date: Fri, 16 Feb 2024 20:39:52 +0300 Subject: [PATCH] fix: only emit "unnecessary else" diagnostic for expr stmts --- crates/hir-ty/src/diagnostics/expr.rs | 64 +++++++++++-------- .../src/handlers/remove_unnecessary_else.rs | 14 +++- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index 718409e1599..4fe75f24b80 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -109,7 +109,7 @@ impl ExprValidator { self.check_for_trailing_return(*body_expr, &body); } Expr::If { .. } => { - self.check_for_unnecessary_else(id, expr, db); + self.check_for_unnecessary_else(id, expr, &body, db); } Expr::Block { .. } => { self.validate_block(db, expr); @@ -337,35 +337,17 @@ impl ExprValidator { } } - fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) { + fn check_for_unnecessary_else( + &mut self, + id: ExprId, + expr: &Expr, + body: &Body, + db: &dyn HirDatabase, + ) { if let Expr::If { condition: _, then_branch, else_branch } = expr { if else_branch.is_none() { return; } - let (body, source_map) = db.body_with_source_map(self.owner); - let Ok(source_ptr) = source_map.expr_syntax(id) else { - return; - }; - let root = source_ptr.file_syntax(db.upcast()); - let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else { - return; - }; - let mut top_if_expr = if_expr; - loop { - let parent = top_if_expr.syntax().parent(); - let has_parent_let_stmt = - parent.as_ref().map_or(false, |node| ast::LetStmt::can_cast(node.kind())); - if has_parent_let_stmt { - // Bail if parent or direct ancestor is a let stmt. - return; - } - let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else { - // Parent is neither an if expr nor a let stmt. - break; - }; - // Check parent if expr. - top_if_expr = parent_if_expr; - } if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] { let last_then_expr = tail.or_else(|| match statements.last()? { Statement::Expr { expr, .. } => Some(*expr), @@ -374,6 +356,36 @@ impl ExprValidator { if let Some(last_then_expr) = last_then_expr { let last_then_expr_ty = &self.infer[last_then_expr]; if last_then_expr_ty.is_never() { + // Only look at sources if the then branch diverges and we have an else branch. + let (_, source_map) = db.body_with_source_map(self.owner); + let Ok(source_ptr) = source_map.expr_syntax(id) else { + return; + }; + let root = source_ptr.file_syntax(db.upcast()); + let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else { + return; + }; + let mut top_if_expr = if_expr; + loop { + let parent = top_if_expr.syntax().parent(); + let has_parent_expr_stmt_or_stmt_list = + parent.as_ref().map_or(false, |node| { + ast::ExprStmt::can_cast(node.kind()) + | ast::StmtList::can_cast(node.kind()) + }); + if has_parent_expr_stmt_or_stmt_list { + // Only emit diagnostic if parent or direct ancestor is either + // an expr stmt or a stmt list. + break; + } + let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else { + // Bail if parent is neither an if expr, an expr stmt nor a stmt list. + return; + }; + // Check parent if expr. + top_if_expr = parent_if_expr; + } + self.diagnostics .push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id }) } diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs index 9564807a334..7bfd64596ed 100644 --- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs +++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs @@ -467,10 +467,10 @@ fn test() { } #[test] - fn no_diagnostic_if_tail_exists_in_else_branch() { + fn no_diagnostic_if_not_expr_stmt() { check_diagnostics_with_needless_return_disabled( r#" -fn test1(a: bool) { +fn test1() { let _x = if a { return; } else { @@ -478,7 +478,7 @@ fn test1(a: bool) { }; } -fn test2(a: bool, b: bool, c: bool) { +fn test2() { let _x = if a { return; } else if b { @@ -491,5 +491,13 @@ fn test2(a: bool, b: bool, c: bool) { } "#, ); + check_diagnostics_with_disabled( + r#" +fn test3() { + foo(if a { return 1 } else { 0 }) +} +"#, + std::iter::once("E0308".to_owned()), + ); } }