wrap_return_type_in_result works on the HIR

This commit is contained in:
Lukas Wirth 2021-07-02 23:19:32 +02:00
parent 3770fce086
commit 26dd0c4e5b
2 changed files with 90 additions and 99 deletions

View File

@ -1,6 +1,6 @@
use std::iter; use std::iter;
use ide_db::helpers::for_each_tail_expr; use ide_db::helpers::{for_each_tail_expr, FamousDefs};
use syntax::{ use syntax::{
ast::{self, make, Expr}, ast::{self, make, Expr},
match_ast, AstNode, match_ast, AstNode,
@ -33,16 +33,15 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext)
_ => return None, _ => return None,
} }
}; };
let body = ast::Expr::BlockExpr(body);
let type_ref = &ret_type.ty()?; let type_ref = &ret_type.ty()?;
let ret_type_str = type_ref.syntax().text().to_string(); let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
let first_part_ret_type = ret_type_str.splitn(2, '<').next(); let result_enum =
if let Some(ret_type_first_part) = first_part_ret_type { FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
if ret_type_first_part.ends_with("Result") {
cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
return None; cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
} return None;
} }
acc.add( acc.add(
@ -50,6 +49,8 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext)
"Wrap return type in Result", "Wrap return type in Result",
type_ref.syntax().text_range(), type_ref.syntax().text_range(),
|builder| { |builder| {
let body = ast::Expr::BlockExpr(body);
let mut exprs_to_wrap = Vec::new(); let mut exprs_to_wrap = Vec::new();
let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
body.walk(&mut |expr| { body.walk(&mut |expr| {
@ -88,6 +89,11 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e)) for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
} }
} }
Expr::ReturnExpr(ret_expr) => {
if let Some(ret_expr_arg) = &ret_expr.expr() {
for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
}
}
e => acc.push(e.clone()), e => acc.push(e.clone()),
} }
} }
@ -98,10 +104,17 @@ mod tests {
use super::*; use super::*;
#[test] fn check(ra_fixture_before: &str, ra_fixture_after: &str) {
fn wrap_return_type_in_result_simple() {
check_assist( check_assist(
wrap_return_type_in_result, wrap_return_type_in_result,
&format!("//- minicore: result\n{}", ra_fixture_before.trim_start()),
ra_fixture_after,
);
}
#[test]
fn wrap_return_type_in_result_simple() {
check(
r#" r#"
fn foo() -> i3$02 { fn foo() -> i3$02 {
let test = "test"; let test = "test";
@ -119,8 +132,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_break_split_tail() { fn wrap_return_type_break_split_tail() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i3$02 { fn foo() -> i3$02 {
loop { loop {
@ -148,8 +160,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_closure() { fn wrap_return_type_in_result_simple_closure() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() { fn foo() {
|| -> i32$0 { || -> i32$0 {
@ -207,7 +218,8 @@ fn wrap_return_type_in_result_simple_return_type_already_result_std() {
check_assist_not_applicable( check_assist_not_applicable(
wrap_return_type_in_result, wrap_return_type_in_result,
r#" r#"
fn foo() -> std::result::Result<i32$0, String> { //- minicore: result
fn foo() -> core::result::Result<i32$0, String> {
let test = "test"; let test = "test";
return 42i32; return 42i32;
} }
@ -221,6 +233,7 @@ fn wrap_return_type_in_result_simple_return_type_already_result() {
check_assist_not_applicable( check_assist_not_applicable(
wrap_return_type_in_result, wrap_return_type_in_result,
r#" r#"
//- minicore: result
fn foo() -> Result<i32$0, String> { fn foo() -> Result<i32$0, String> {
let test = "test"; let test = "test";
return 42i32; return 42i32;
@ -246,8 +259,7 @@ fn foo() {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_cursor() { fn wrap_return_type_in_result_simple_with_cursor() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> $0i32 { fn foo() -> $0i32 {
let test = "test"; let test = "test";
@ -265,8 +277,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail() { fn wrap_return_type_in_result_simple_with_tail() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() ->$0 i32 { fn foo() ->$0 i32 {
let test = "test"; let test = "test";
@ -284,8 +295,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_closure() { fn wrap_return_type_in_result_simple_with_tail_closure() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() { fn foo() {
|| ->$0 i32 { || ->$0 i32 {
@ -307,17 +317,12 @@ fn foo() {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_only() { fn wrap_return_type_in_result_simple_with_tail_only() {
check_assist( check(r#"fn foo() -> i32$0 { 42i32 }"#, r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#);
wrap_return_type_in_result,
r#"fn foo() -> i32$0 { 42i32 }"#,
r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
);
} }
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_block_like() { fn wrap_return_type_in_result_simple_with_tail_block_like() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
if true { if true {
@ -341,8 +346,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_without_block_closure() { fn wrap_return_type_in_result_simple_without_block_closure() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() { fn foo() {
|| -> i32$0 { || -> i32$0 {
@ -370,8 +374,7 @@ fn foo() {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_nested_if() { fn wrap_return_type_in_result_simple_with_nested_if() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
if true { if true {
@ -403,8 +406,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_await() { fn wrap_return_type_in_result_simple_with_await() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
async fn foo() -> i$032 { async fn foo() -> i$032 {
if true { if true {
@ -436,8 +438,7 @@ async fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_array() { fn wrap_return_type_in_result_simple_with_array() {
check_assist( check(
wrap_return_type_in_result,
r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#, r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#, r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
); );
@ -445,8 +446,7 @@ fn wrap_return_type_in_result_simple_with_array() {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_cast() { fn wrap_return_type_in_result_simple_with_cast() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -$0> i32 { fn foo() -$0> i32 {
if true { if true {
@ -478,8 +478,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_block_like_match() { fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = 5; let my_var = 5;
@ -503,8 +502,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_loop_with_tail() { fn wrap_return_type_in_result_simple_with_loop_with_tail() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = 5; let my_var = 5;
@ -530,8 +528,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() { fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = let x = loop { let my_var = let x = loop {
@ -553,8 +550,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() { fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = 5; let my_var = 5;
@ -577,8 +573,7 @@ fn foo() -> Result<i32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = 5; let my_var = 5;
@ -606,8 +601,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() { fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let my_var = 5; let my_var = 5;
@ -655,8 +649,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() { fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i$032 { fn foo() -> i$032 {
let test = "test"; let test = "test";
@ -680,8 +673,7 @@ fn foo() -> Result<i32, ${0:_}> {
#[test] #[test]
fn wrap_return_type_in_result_simple_with_closure() { fn wrap_return_type_in_result_simple_with_closure() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) ->$0 u32 { fn foo(the_field: u32) ->$0 u32 {
let true_closure = || { return true; }; let true_closure = || { return true; };
@ -712,55 +704,53 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) -> u32$0 { fn foo(the_field: u32) -> u32$0 {
let true_closure = || { let true_closure = || {
return true; return true;
}; };
if the_field < 5 { if the_field < 5 {
let mut i = 0; let mut i = 0;
if true_closure() { if true_closure() {
return 99; return 99;
} else { } else {
return 0; return 0;
} }
} }
let t = None; let t = None;
t.unwrap_or_else(|| the_field) t.unwrap_or_else(|| the_field)
} }
"#, "#,
r#" r#"
fn foo(the_field: u32) -> Result<u32, ${0:_}> { fn foo(the_field: u32) -> Result<u32, ${0:_}> {
let true_closure = || { let true_closure = || {
return true; return true;
}; };
if the_field < 5 { if the_field < 5 {
let mut i = 0; let mut i = 0;
if true_closure() { if true_closure() {
return Ok(99); return Ok(99);
} else { } else {
return Ok(0); return Ok(0);
} }
} }
let t = None; let t = None;
Ok(t.unwrap_or_else(|| the_field)) Ok(t.unwrap_or_else(|| the_field))
} }
"#, "#,
); );
} }
#[test] #[test]
fn wrap_return_type_in_result_simple_with_weird_forms() { fn wrap_return_type_in_result_simple_with_weird_forms() {
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo() -> i32$0 { fn foo() -> i32$0 {
let test = "test"; let test = "test";
@ -793,8 +783,7 @@ fn foo() -> Result<i32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) -> u32$0 { fn foo(the_field: u32) -> u32$0 {
if the_field < 5 { if the_field < 5 {
@ -833,8 +822,7 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) -> u3$02 { fn foo(the_field: u32) -> u3$02 {
if the_field < 5 { if the_field < 5 {
@ -861,8 +849,7 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) -> u32$0 { fn foo(the_field: u32) -> u32$0 {
if the_field < 5 { if the_field < 5 {
@ -891,8 +878,7 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#, "#,
); );
check_assist( check(
wrap_return_type_in_result,
r#" r#"
fn foo(the_field: u32) -> $0u32 { fn foo(the_field: u32) -> $0u32 {
if the_field < 5 { if the_field < 5 {

View File

@ -122,6 +122,10 @@ pub fn core_option_Option(&self) -> Option<Enum> {
self.find_enum("core:option:Option") self.find_enum("core:option:Option")
} }
pub fn core_result_Result(&self) -> Option<Enum> {
self.find_enum("core:result:Result")
}
pub fn core_default_Default(&self) -> Option<Trait> { pub fn core_default_Default(&self) -> Option<Trait> {
self.find_trait("core:default:Default") self.find_trait("core:default:Default")
} }
@ -206,6 +210,7 @@ pub const fn new(allow_snippets: bool) -> Option<SnippetCap> {
} }
/// Calls `cb` on each expression inside `expr` that is at "tail position". /// Calls `cb` on each expression inside `expr` that is at "tail position".
/// Does not walk into `break` or `return` expressions.
pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) {
match expr { match expr {
ast::Expr::BlockExpr(b) => { ast::Expr::BlockExpr(b) => {