use hir::AssocItem; use ide_db::{ assists::{AssistId, AssistKind}, base_db::FileId, defs::Definition, search::FileReference, syntax_helpers::node_ext::full_path_of_name_ref, traits::resolve_target_trait, }; use syntax::{ ast::{self, HasName, NameLike, NameRef}, AstNode, SyntaxKind, TextRange, }; use crate::{AssistContext, Assists}; // Assist: unnecessary_async // // Removes the `async` mark from functions which have no `.await` in their body. // Looks for calls to the functions and removes the `.await` on the call site. // // ``` // pub async f$0n foo() {} // pub async fn bar() { foo().await } // ``` // -> // ``` // pub fn foo() {} // pub async fn bar() { foo() } // ``` pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let function: ast::Fn = ctx.find_node_at_offset()?; // Do nothing if the cursor is not on the prototype. This is so that the check does not pollute // when the user asks us for assists when in the middle of the function body. // We consider the prototype to be anything that is before the body of the function. let cursor_position = ctx.offset(); if cursor_position >= function.body()?.syntax().text_range().start() { return None; } // Do nothing if the function isn't async. if let None = function.async_token() { return None; } // Do nothing if the function has an `await` expression in its body. if function.body()?.syntax().descendants().find_map(ast::AwaitExpr::cast).is_some() { return None; } // Do nothing if the method is an async member of trait. if let Some(fname) = function.name() { if let Some(trait_item) = find_corresponding_trait_member(ctx, fname.to_string()) { if let AssocItem::Function(method) = trait_item { if method.is_async(ctx.db()) { return None; } } } } // Remove the `async` keyword plus whitespace after it, if any. let async_range = { let async_token = function.async_token()?; let next_token = async_token.next_token()?; if matches!(next_token.kind(), SyntaxKind::WHITESPACE) { TextRange::new(async_token.text_range().start(), next_token.text_range().end()) } else { async_token.text_range() } }; // Otherwise, we may remove the `async` keyword. acc.add( AssistId("unnecessary_async", AssistKind::QuickFix), "Remove unnecessary async", async_range, |edit| { // Remove async on the function definition. edit.replace(async_range, ""); // Remove all `.await`s from calls to the function we remove `async` from. if let Some(fn_def) = ctx.sema.to_def(&function) { for await_expr in find_all_references(ctx, &Definition::Function(fn_def)) // Keep only references that correspond NameRefs. .filter_map(|(_, reference)| match reference.name { NameLike::NameRef(nameref) => Some(nameref), _ => None, }) // Keep only references that correspond to await expressions .filter_map(|nameref| find_await_expression(ctx, &nameref)) { if let Some(await_token) = &await_expr.await_token() { edit.replace(await_token.text_range(), ""); } if let Some(dot_token) = &await_expr.dot_token() { edit.replace(dot_token.text_range(), ""); } } } }, ) } fn find_corresponding_trait_member( ctx: &AssistContext<'_>, function_name: String, ) -> Option { let impl_ = ctx.find_node_at_offset::()?; let trait_ = resolve_target_trait(&ctx.sema, &impl_)?; trait_ .items(ctx.db()) .iter() .find(|item| match item.name(ctx.db()) { Some(method_name) => method_name.to_string() == function_name, _ => false, }) .cloned() } fn find_all_references( ctx: &AssistContext<'_>, def: &Definition, ) -> impl Iterator { def.usages(&ctx.sema).all().into_iter().flat_map(|(file_id, references)| { references.into_iter().map(move |reference| (file_id, reference)) }) } /// Finds the await expression for the given `NameRef`. /// If no await expression is found, returns None. fn find_await_expression(ctx: &AssistContext<'_>, nameref: &NameRef) -> Option { // From the nameref, walk up the tree to the await expression. let await_expr = if let Some(path) = full_path_of_name_ref(&nameref) { // Function calls. path.syntax() .parent() .and_then(ast::PathExpr::cast)? .syntax() .parent() .and_then(ast::CallExpr::cast)? .syntax() .parent() .and_then(ast::AwaitExpr::cast) } else { // Method calls. nameref .syntax() .parent() .and_then(ast::MethodCallExpr::cast)? .syntax() .parent() .and_then(ast::AwaitExpr::cast) }; ctx.sema.original_ast_node(await_expr?) } #[cfg(test)] mod tests { use super::*; use crate::tests::{check_assist, check_assist_not_applicable}; #[test] fn applies_on_empty_function() { check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}") } #[test] fn applies_and_removes_whitespace() { check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}") } #[test] fn does_not_apply_on_non_async_function() { check_assist_not_applicable(unnecessary_async, "pub f$0n f() {}") } #[test] fn applies_on_function_with_a_non_await_expr() { check_assist(unnecessary_async, "pub async f$0n f() { f2() }", "pub fn f() { f2() }") } #[test] fn does_not_apply_on_function_with_an_await_expr() { check_assist_not_applicable(unnecessary_async, "pub async f$0n f() { f2().await }") } #[test] fn applies_and_removes_await_on_reference() { check_assist( unnecessary_async, r#" pub async fn f4() { } pub async f$0n f2() { } pub async fn f() { f2().await } pub async fn f3() { f2().await }"#, r#" pub async fn f4() { } pub fn f2() { } pub async fn f() { f2() } pub async fn f3() { f2() }"#, ) } #[test] fn applies_and_removes_await_from_within_module() { check_assist( unnecessary_async, r#" pub async fn f4() { } mod a { pub async f$0n f2() { } } pub async fn f() { a::f2().await } pub async fn f3() { a::f2().await }"#, r#" pub async fn f4() { } mod a { pub fn f2() { } } pub async fn f() { a::f2() } pub async fn f3() { a::f2() }"#, ) } #[test] fn applies_and_removes_await_on_inner_await() { check_assist( unnecessary_async, // Ensure that it is the first await on the 3rd line that is removed r#" pub async fn f() { f2().await } pub async f$0n f2() -> i32 { 1 } pub async fn f3() { f4(f2().await).await } pub async fn f4(i: i32) { }"#, r#" pub async fn f() { f2() } pub fn f2() -> i32 { 1 } pub async fn f3() { f4(f2()).await } pub async fn f4(i: i32) { }"#, ) } #[test] fn applies_and_removes_await_on_outer_await() { check_assist( unnecessary_async, // Ensure that it is the second await on the 3rd line that is removed r#" pub async fn f() { f2().await } pub async f$0n f2(i: i32) { } pub async fn f3() { f2(f4().await).await } pub async fn f4() -> i32 { 1 }"#, r#" pub async fn f() { f2() } pub fn f2(i: i32) { } pub async fn f3() { f2(f4().await) } pub async fn f4() -> i32 { 1 }"#, ) } #[test] fn applies_on_method_call() { check_assist( unnecessary_async, r#" pub struct S { } impl S { pub async f$0n f2(&self) { } } pub async fn f(s: &S) { s.f2().await }"#, r#" pub struct S { } impl S { pub fn f2(&self) { } } pub async fn f(s: &S) { s.f2() }"#, ) } #[test] fn does_not_apply_on_function_with_a_nested_await_expr() { check_assist_not_applicable( unnecessary_async, "async f$0n f() { if true { loop { f2().await } } }", ) } #[test] fn does_not_apply_when_not_on_prototype() { check_assist_not_applicable(unnecessary_async, "pub async fn f() { $0f2() }") } #[test] fn applies_on_unnecessary_async_on_trait_method() { check_assist( unnecessary_async, r#" trait Trait { fn foo(); } impl Trait for () { $0async fn foo() {} }"#, r#" trait Trait { fn foo(); } impl Trait for () { fn foo() {} }"#, ); } #[test] fn does_not_apply_on_async_trait_method() { check_assist_not_applicable( unnecessary_async, r#" trait Trait { async fn foo(); } impl Trait for () { $0async fn foo() {} }"#, ); } }