diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index 60086ed4a4e..c709c8c482b 100644 --- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -1,6 +1,5 @@ use hir::db::AstDatabase; use ide_db::{assists::Assist, source_change::SourceChange}; -use syntax::ast::{ExprStmt, LetStmt}; use syntax::AstNode; use syntax::{ast, SyntaxNode}; use text_edit::TextEdit; @@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option, d: &hir::MissingUnsafe) -> Option, - expr: &ast::Expr, -) -> SyntaxNode { - let Some(let_or_expr_stmt) = ctx.sema.ancestors_with_macros(expr.syntax().clone()).find(|node| { - LetStmt::can_cast(node.kind()) || ExprStmt::can_cast(node.kind()) - }) else { - // Is this reachable? - return expr.syntax().clone(); - }; - let_or_expr_stmt +// Pick the first ancestor expression of the unsafe `expr` that is not a +// receiver of a method call, a field access, the left-hand side of an +// assignment, or a reference. As all of those cases would incur a forced move +// if wrapped which might not be wanted. That is: +// - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }` +// - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }` +// - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }` +// - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }` +// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }` +// - `&unsafe_expr` -> `unsafe { &unsafe_expr }` +// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }` +fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> SyntaxNode { + // The `unsafe_expr` might be: + // - `ast::CallExpr`: call an unsafe function + // - `ast::MethodCallExpr`: call an unsafe method + // - `ast::PrefixExpr`: dereference a raw pointer + // - `ast::PathExpr`: access a static mut variable + for node in unsafe_expr.syntax().ancestors() { + let Some(parent) = node.parent() else { + return node; + }; + match parent.kind() { + syntax::SyntaxKind::METHOD_CALL_EXPR => { + // Check if the `node` is the receiver of the method call + let method_call_expr = ast::MethodCallExpr::cast(parent.clone()).unwrap(); + if method_call_expr + .receiver() + .map(|receiver| { + receiver.syntax().text_range().contains_range(node.text_range()) + }) + .unwrap_or(false) + { + // Actually, I think it's not necessary to check whether the + // text range of the `node` (which is the ancestor of the + // `unsafe_expr`) is contained in the text range of the + // receiver. The `node` could potentially be the receiver, the + // method name, or the argument list. Since the `node` is the + // ancestor of the unsafe_expr, it cannot be the method name. + // Additionally, if the `node` is the argument list, the loop + // would break at least when `parent` reaches the argument list. + // + // Dispite this, I still check the text range because I think it + // makes the code easier to understand. + continue; + } + return node; + } + syntax::SyntaxKind::FIELD_EXPR | syntax::SyntaxKind::REF_EXPR => continue, + syntax::SyntaxKind::BIN_EXPR => { + // Check if the `node` is the left-hand side of an assignment + let is_left_hand_side_of_assignment = { + let bin_expr = ast::BinExpr::cast(parent.clone()).unwrap(); + if let Some(ast::BinaryOp::Assignment { .. }) = bin_expr.op_kind() { + let is_left_hand_side = bin_expr + .lhs() + .map(|lhs| lhs.syntax().text_range().contains_range(node.text_range())) + .unwrap_or(false); + is_left_hand_side + } else { + false + } + }; + if !is_left_hand_side_of_assignment { + return node; + } + } + _ => { + return node; + } + } + } + unsafe_expr.syntax().clone() } #[cfg(test)] @@ -168,7 +206,7 @@ fn main() { r#" fn main() { let x = &5 as *const usize; - unsafe { let z = *x; } + let z = unsafe { *x }; } "#, ); @@ -192,7 +230,7 @@ unsafe fn func() { let z = *x; } fn main() { - unsafe { func(); } + unsafe { func() }; } "#, ) @@ -224,7 +262,7 @@ impl S { } fn main() { let s = S(5); - unsafe { s.func(); } + unsafe { s.func() }; } "#, ) @@ -252,7 +290,7 @@ struct Ty { static mut STATIC_MUT: Ty = Ty { a: 0 }; fn main() { - unsafe { let x = STATIC_MUT.a; } + let x = unsafe { STATIC_MUT.a }; } "#, ) @@ -276,7 +314,155 @@ extern "rust-intrinsic" { } fn main() { - unsafe { let _ = floorf32(12.0); } + let _ = unsafe { floorf32(12.0) }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_a_receiver_of_a_method_call() { + check_fix( + r#" +unsafe fn foo() -> String { + "string".to_string() +} + +fn main() { + foo$0().len(); +} +"#, + r#" +unsafe fn foo() -> String { + "string".to_string() +} + +fn main() { + unsafe { foo().len() }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_an_argument_of_a_method_call() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let mut v = vec![]; + v.push(STATIC_MUT$0); +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let mut v = vec![]; + v.push(unsafe { STATIC_MUT }); +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_left_hand_side_of_assignment() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + STATIC_MUT$0 = 1; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + unsafe { STATIC_MUT = 1 }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_as_right_hand_side_of_assignment() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x; + x = STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x; + x = unsafe { STATIC_MUT }; +} +"#, + ) + } + + #[test] + fn unsafe_expr_in_binary_plus() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = STATIC_MUT$0 + 1; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { STATIC_MUT } + 1; +} +"#, + ) + } + + #[test] + fn ref_to_unsafe_expr() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = &STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { &STATIC_MUT }; +} +"#, + ) + } + + #[test] + fn ref_ref_to_unsafe_expr() { + check_fix( + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = &&STATIC_MUT$0; +} +"#, + r#" +static mut STATIC_MUT: u8 = 0; + +fn main() { + let x = unsafe { &&STATIC_MUT }; } "#, )