Correctly handle inlining of async fn

This commit is contained in:
oxalica 2023-06-18 16:59:11 +08:00
parent fcfc6afe05
commit 52f1ce17aa
2 changed files with 146 additions and 10 deletions

View File

@ -15,7 +15,7 @@
}; };
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use syntax::{ use syntax::{
ast::{self, edit_in_place::Indent, HasArgList, PathExpr}, ast::{self, edit::IndentLevel, edit_in_place::Indent, HasArgList, PathExpr},
ted, AstNode, NodeOrToken, SyntaxKind, ted, AstNode, NodeOrToken, SyntaxKind,
}; };
@ -306,7 +306,7 @@ fn inline(
params: &[(ast::Pat, Option<ast::Type>, hir::Param)], params: &[(ast::Pat, Option<ast::Type>, hir::Param)],
CallInfo { node, arguments, generic_arg_list }: &CallInfo, CallInfo { node, arguments, generic_arg_list }: &CallInfo,
) -> ast::Expr { ) -> ast::Expr {
let body = if sema.hir_file_for(fn_body.syntax()).is_macro() { let mut body = if sema.hir_file_for(fn_body.syntax()).is_macro() {
cov_mark::hit!(inline_call_defined_in_macro); cov_mark::hit!(inline_call_defined_in_macro);
if let Some(body) = ast::BlockExpr::cast(insert_ws_into(fn_body.syntax().clone())) { if let Some(body) = ast::BlockExpr::cast(insert_ws_into(fn_body.syntax().clone())) {
body body
@ -391,19 +391,19 @@ fn inline(
} }
} }
let mut let_stmts = Vec::new();
// Inline parameter expressions or generate `let` statements depending on whether inlining works or not. // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() { for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments) {
// izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
let usages: &[ast::PathExpr] = &usages; let usages: &[ast::PathExpr] = &usages;
let expr: &ast::Expr = expr; let expr: &ast::Expr = expr;
let insert_let_stmt = || { let mut insert_let_stmt = || {
let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone()); let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
if let Some(stmt_list) = body.stmt_list() { let_stmts.push(
stmt_list.push_front( make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(),
make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(), );
)
}
}; };
// check if there is a local var in the function that conflicts with parameter // check if there is a local var in the function that conflicts with parameter
@ -457,6 +457,24 @@ fn inline(
} }
} }
let is_async_fn = function.is_async(sema.db);
if is_async_fn {
cov_mark::hit!(inline_call_async_fn);
body = make::async_move_block_expr(body.statements(), body.tail_expr()).clone_for_update();
// Arguments should be evaluated outside the async block, and then moved into it.
if !let_stmts.is_empty() {
cov_mark::hit!(inline_call_async_fn_with_let_stmts);
body.indent(IndentLevel(1));
body = make::block_expr(let_stmts, Some(body.into())).clone_for_update();
}
} else if let Some(stmt_list) = body.stmt_list() {
ted::insert_all(
ted::Position::after(stmt_list.l_curly_token().unwrap()),
let_stmts.into_iter().map(|stmt| stmt.syntax().clone().into()).collect(),
);
}
let original_indentation = match node { let original_indentation = match node {
ast::CallableExpr::Call(it) => it.indent_level(), ast::CallableExpr::Call(it) => it.indent_level(),
ast::CallableExpr::MethodCall(it) => it.indent_level(), ast::CallableExpr::MethodCall(it) => it.indent_level(),
@ -464,7 +482,7 @@ fn inline(
body.reindent_to(original_indentation); body.reindent_to(original_indentation);
match body.tail_expr() { match body.tail_expr() {
Some(expr) if body.statements().next().is_none() => expr, Some(expr) if !is_async_fn && body.statements().next().is_none() => expr,
_ => match node _ => match node
.syntax() .syntax()
.parent() .parent()
@ -1351,6 +1369,109 @@ fn main() {
bar * b * a * 6 bar * b * a * 6
}; };
} }
"#,
);
}
#[test]
fn async_fn_single_expression() {
cov_mark::check!(inline_call_async_fn);
check_assist(
inline_call,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(arg: u32) -> u32 {
bar(arg).await * 2
}
fn spawn<T>(_: T) {}
fn main() {
spawn(foo$0(42));
}
"#,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(arg: u32) -> u32 {
bar(arg).await * 2
}
fn spawn<T>(_: T) {}
fn main() {
spawn(async move {
bar(42).await * 2
});
}
"#,
);
}
#[test]
fn async_fn_multiple_statements() {
cov_mark::check!(inline_call_async_fn);
check_assist(
inline_call,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(arg: u32) -> u32 {
bar(arg).await;
42
}
fn spawn<T>(_: T) {}
fn main() {
spawn(foo$0(42));
}
"#,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(arg: u32) -> u32 {
bar(arg).await;
42
}
fn spawn<T>(_: T) {}
fn main() {
spawn(async move {
bar(42).await;
42
});
}
"#,
);
}
#[test]
fn async_fn_with_let_statements() {
cov_mark::check!(inline_call_async_fn);
cov_mark::check!(inline_call_async_fn_with_let_stmts);
check_assist(
inline_call,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(x: u32, y: u32, z: &u32) -> u32 {
bar(x).await;
y + y + *z
}
fn spawn<T>(_: T) {}
fn main() {
let var = 42;
spawn(foo$0(var, var + 1, &var));
}
"#,
r#"
async fn bar(x: u32) -> u32 { x + 1 }
async fn foo(x: u32, y: u32, z: &u32) -> u32 {
bar(x).await;
y + y + *z
}
fn spawn<T>(_: T) {}
fn main() {
let var = 42;
spawn({
let y = var + 1;
let z: &u32 = &var;
async move {
bar(var).await;
y + y + *z
}
});
}
"#, "#,
); );
} }

View File

@ -447,6 +447,21 @@ pub fn block_expr(
ast_from_text(&format!("fn f() {buf}")) ast_from_text(&format!("fn f() {buf}"))
} }
pub fn async_move_block_expr(
stmts: impl IntoIterator<Item = ast::Stmt>,
tail_expr: Option<ast::Expr>,
) -> ast::BlockExpr {
let mut buf = "async move {\n".to_string();
for stmt in stmts.into_iter() {
format_to!(buf, " {stmt}\n");
}
if let Some(tail_expr) = tail_expr {
format_to!(buf, " {tail_expr}\n");
}
buf += "}";
ast_from_text(&format!("const _: () = {buf};"))
}
pub fn tail_only_block_expr(tail_expr: ast::Expr) -> ast::BlockExpr { pub fn tail_only_block_expr(tail_expr: ast::Expr) -> ast::BlockExpr {
ast_from_text(&format!("fn f() {{ {tail_expr} }}")) ast_from_text(&format!("fn f() {{ {tail_expr} }}"))
} }