Add infer_function_return_type assist
This commit is contained in:
parent
7709b6a2d4
commit
a14df19d82
113
crates/assists/src/handlers/infer_function_return_type.rs
Normal file
113
crates/assists/src/handlers/infer_function_return_type.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use hir::HirDisplay;
|
||||
use syntax::{ast, AstNode, TextSize};
|
||||
use test_utils::mark;
|
||||
|
||||
use crate::{AssistContext, AssistId, AssistKind, Assists};
|
||||
|
||||
// Assist: infer_function_return_type
|
||||
//
|
||||
// Adds the return type to a function inferred from its tail expression if it doesn't have a return
|
||||
// type specified.
|
||||
//
|
||||
// ```
|
||||
// fn foo() { 4<|>2i32 }
|
||||
// ```
|
||||
// ->
|
||||
// ```
|
||||
// fn foo() -> i32 { 42i32 }
|
||||
// ```
|
||||
pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
|
||||
let expr = ctx.find_node_at_offset::<ast::Expr>()?;
|
||||
let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
|
||||
|
||||
if func.ret_type().is_some() {
|
||||
mark::hit!(existing_ret_type);
|
||||
return None;
|
||||
}
|
||||
let body = func.body()?;
|
||||
let tail_expr = body.expr()?;
|
||||
// check whether the expr we were at is indeed the tail expression
|
||||
if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
|
||||
mark::hit!(not_tail_expr);
|
||||
return None;
|
||||
}
|
||||
let module = ctx.sema.scope(func.syntax()).module()?;
|
||||
let ty = ctx.sema.type_of_expr(&tail_expr)?;
|
||||
let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
|
||||
let rparen = func.param_list()?.r_paren_token()?;
|
||||
|
||||
acc.add(
|
||||
AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
|
||||
"Wrap return type in Result",
|
||||
tail_expr.syntax().text_range(),
|
||||
|builder| {
|
||||
let insert_pos = rparen.text_range().end() + TextSize::from(1);
|
||||
|
||||
builder.insert(insert_pos, &format!("-> {} ", ty));
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tests::{check_assist, check_assist_not_applicable};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn infer_return_type() {
|
||||
check_assist(
|
||||
infer_function_return_type,
|
||||
r#"fn foo() {
|
||||
45<|>
|
||||
}"#,
|
||||
r#"fn foo() -> i32 {
|
||||
45
|
||||
}"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_return_type_nested() {
|
||||
check_assist(
|
||||
infer_function_return_type,
|
||||
r#"fn foo() {
|
||||
if true {
|
||||
3<|>
|
||||
} else {
|
||||
5
|
||||
}
|
||||
}"#,
|
||||
r#"fn foo() -> i32 {
|
||||
if true {
|
||||
3
|
||||
} else {
|
||||
5
|
||||
}
|
||||
}"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_applicable_ret_type_specified() {
|
||||
mark::check!(existing_ret_type);
|
||||
check_assist_not_applicable(
|
||||
infer_function_return_type,
|
||||
r#"fn foo() -> i32 {
|
||||
( 45<|> + 32 ) * 123
|
||||
}"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_applicable_non_tail_expr() {
|
||||
mark::check!(not_tail_expr);
|
||||
check_assist_not_applicable(
|
||||
infer_function_return_type,
|
||||
r#"fn foo() {
|
||||
let x = <|>3;
|
||||
( 45 + 32 ) * 123
|
||||
}"#,
|
||||
);
|
||||
}
|
||||
}
|
@ -143,6 +143,7 @@ mod handlers {
|
||||
mod generate_function;
|
||||
mod generate_impl;
|
||||
mod generate_new;
|
||||
mod infer_function_return_type;
|
||||
mod inline_local_variable;
|
||||
mod introduce_named_lifetime;
|
||||
mod invert_if;
|
||||
@ -190,6 +191,7 @@ pub(crate) fn all() -> &'static [Handler] {
|
||||
generate_function::generate_function,
|
||||
generate_impl::generate_impl,
|
||||
generate_new::generate_new,
|
||||
infer_function_return_type::infer_function_return_type,
|
||||
inline_local_variable::inline_local_variable,
|
||||
introduce_named_lifetime::introduce_named_lifetime,
|
||||
invert_if::invert_if,
|
||||
|
@ -505,6 +505,19 @@ fn $0new(data: T) -> Self { Self { data } }
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doctest_infer_function_return_type() {
|
||||
check_doc_test(
|
||||
"infer_function_return_type",
|
||||
r#####"
|
||||
fn foo() { 4<|>2i32 }
|
||||
"#####,
|
||||
r#####"
|
||||
fn foo() -> i32 { 42i32 }
|
||||
"#####,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doctest_inline_local_variable() {
|
||||
check_doc_test(
|
||||
|
Loading…
Reference in New Issue
Block a user