diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index 520d07ae067..aa584eb0348 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs @@ -17,7 +17,7 @@ // fn foo() -> i32 { 42i32 } // ``` pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { - let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?; + let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?; let module = ctx.sema.scope(tail_expr.syntax()).module()?; let ty = ctx.sema.type_of_expr(&tail_expr)?; if ty.is_unit() { @@ -27,7 +27,10 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) acc.add( AssistId("infer_function_return_type", AssistKind::RefactorRewrite), - "Add this function's return type", + match fn_type { + FnType::Function => "Add this function's return type", + FnType::Closure { .. } => "Add this closure's return type", + }, tail_expr.syntax().text_range(), |builder| { match builder_edit_pos { @@ -38,7 +41,7 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) builder.replace(text_range, &format!("-> {}", ty)) } } - if wrap_expr { + if let FnType::Closure { wrap_expr: true } = fn_type { mark::hit!(wrap_closure_non_block_expr); // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); @@ -72,8 +75,13 @@ fn ret_ty_to_action(ret_ty: Option, insert_pos: TextSize) -> Optio } } -fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> { - let (tail_expr, return_type_range, action, wrap_expr) = +enum FnType { + Function, + Closure { wrap_expr: bool }, +} + +fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> { + let (fn_type, tail_expr, return_type_range, action) = if let Some(closure) = ctx.find_node_at_offset::() { let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end(); let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?; @@ -86,7 +94,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool }; let ret_range = TextRange::new(rpipe_pos, body_start); - (tail_expr, ret_range, action, wrap_expr) + (FnType::Closure { wrap_expr }, tail_expr, ret_range, action) } else { let func = ctx.find_node_at_offset::()?; let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end(); @@ -97,7 +105,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool let ret_range_end = body.l_curly_token()?.text_range().start(); let ret_range = TextRange::new(rparen_pos, ret_range_end); - (tail_expr, ret_range, action, false) + (FnType::Function, tail_expr, ret_range, action) }; let frange = ctx.frange.range; if return_type_range.contains_range(frange) { @@ -109,7 +117,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool } else { return None; } - Some((tail_expr, action, wrap_expr)) + Some((fn_type, tail_expr, action)) } #[cfg(test)]