diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index c709c8c482b..eb32db25065 100644 --- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -1,7 +1,7 @@ -use hir::db::AstDatabase; +use hir::db::ExpandDatabase; use ide_db::{assists::Assist, source_change::SourceChange}; -use syntax::AstNode; use syntax::{ast, SyntaxNode}; +use syntax::{match_ast, AstNode}; use text_edit::TextEdit; use crate::{fix, Diagnostic, DiagnosticsContext}; @@ -19,10 +19,15 @@ pub(crate) fn missing_unsafe(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsaf } fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option> { + // The fixit will not work correctly for macro expansions, so we don't offer it in that case. + if d.expr.file_id.is_macro() { + return None; + } + let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; let expr = d.expr.value.to_node(&root); - let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr); + let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr)?; let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text()); let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement); @@ -42,72 +47,51 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option `unsafe { unsafe_expr += 1 }` // - `&unsafe_expr` -> `unsafe { &unsafe_expr }` // - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }` -fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> SyntaxNode { +fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> Option { // The `unsafe_expr` might be: // - `ast::CallExpr`: call an unsafe function // - `ast::MethodCallExpr`: call an unsafe method // - `ast::PrefixExpr`: dereference a raw pointer // - `ast::PathExpr`: access a static mut variable - for node in unsafe_expr.syntax().ancestors() { - let Some(parent) = node.parent() else { - return node; - }; - match parent.kind() { - syntax::SyntaxKind::METHOD_CALL_EXPR => { - // Check if the `node` is the receiver of the method call - let method_call_expr = ast::MethodCallExpr::cast(parent.clone()).unwrap(); - if method_call_expr - .receiver() - .map(|receiver| { - receiver.syntax().text_range().contains_range(node.text_range()) - }) - .unwrap_or(false) - { - // Actually, I think it's not necessary to check whether the - // text range of the `node` (which is the ancestor of the - // `unsafe_expr`) is contained in the text range of the - // receiver. The `node` could potentially be the receiver, the - // method name, or the argument list. Since the `node` is the - // ancestor of the unsafe_expr, it cannot be the method name. - // Additionally, if the `node` is the argument list, the loop - // would break at least when `parent` reaches the argument list. - // - // Dispite this, I still check the text range because I think it - // makes the code easier to understand. - continue; - } - return node; - } - syntax::SyntaxKind::FIELD_EXPR | syntax::SyntaxKind::REF_EXPR => continue, - syntax::SyntaxKind::BIN_EXPR => { - // Check if the `node` is the left-hand side of an assignment - let is_left_hand_side_of_assignment = { - let bin_expr = ast::BinExpr::cast(parent.clone()).unwrap(); - if let Some(ast::BinaryOp::Assignment { .. }) = bin_expr.op_kind() { - let is_left_hand_side = bin_expr - .lhs() - .map(|lhs| lhs.syntax().text_range().contains_range(node.text_range())) - .unwrap_or(false); - is_left_hand_side - } else { - false + for (node, parent) in + unsafe_expr.syntax().ancestors().zip(unsafe_expr.syntax().ancestors().skip(1)) + { + match_ast! { + match parent { + // If the `parent` is a `MethodCallExpr`, that means the `node` + // is the receiver of the method call, because only the receiver + // can be a direct child of a method call. The method name + // itself is not an expression but a `NameRef`, and an argument + // is a direct child of an `ArgList`. + ast::MethodCallExpr(_) => continue, + ast::FieldExpr(_) => continue, + ast::RefExpr(_) => continue, + ast::BinExpr(it) => { + // Check if the `node` is the left-hand side of an + // assignment, if so, we don't want to wrap it in an unsafe + // block, e.g. `unsafe_expr += 1` + let is_left_hand_side_of_assignment = { + if let Some(ast::BinaryOp::Assignment { .. }) = it.op_kind() { + it.lhs().map(|lhs| lhs.syntax().text_range().contains_range(node.text_range())).unwrap_or(false) + } else { + false + } + }; + if !is_left_hand_side_of_assignment { + return Some(node); } - }; - if !is_left_hand_side_of_assignment { - return node; - } - } - _ => { - return node; + }, + _ => { return Some(node); } + } } } - unsafe_expr.syntax().clone() + None } #[cfg(test)] mod tests { - use crate::tests::{check_diagnostics, check_fix}; + use crate::tests::{check_diagnostics, check_fix, check_no_fix}; #[test] fn missing_unsafe_diagnostic_with_raw_ptr() { @@ -467,4 +451,19 @@ fn main() { "#, ) } + + #[test] + fn unsafe_expr_in_macro_call() { + check_no_fix( + r#" +unsafe fn foo() -> u8 { + 0 +} + +fn main() { + let x = format!("foo: {}", foo$0()); +} + "#, + ) + } }