diff --git a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs index 9ef4ae047ef..26f3c192617 100644 --- a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs +++ b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use syntax::{ ast::{self, Expr}, - match_ast, AstNode, TextRange, TextSize, + match_ast, AstNode, NodeOrToken, SyntaxKind, TextRange, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<' }; let type_ref = &ret_type.ty()?; - let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + let Some(hir::Adt::Enum(ret_enum)) = ctx.sema.resolve_type(type_ref)?.as_adt() else { return None; }; let result_enum = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?; - - if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + if ret_enum != result_enum { return None; } + let Some(ok_type) = unwrap_result_type(type_ref) else { return None; }; + acc.add( AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite), "Unwrap Result return type", @@ -64,26 +65,19 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<' }); for_each_tail_expr(&body, tail_cb); - let mut is_unit_type = false; - if let Some((_, inner_type)) = type_ref.to_string().split_once('<') { - let inner_type = match inner_type.split_once(',') { - Some((success_inner_type, _)) => success_inner_type, - None => inner_type, - }; - let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type); - if new_ret_type == "()" { - is_unit_type = true; - let text_range = TextRange::new( - ret_type.syntax().text_range().start(), - ret_type.syntax().text_range().end() + TextSize::from(1u32), - ); - builder.delete(text_range) - } else { - builder.replace( - type_ref.syntax().text_range(), - inner_type.strip_suffix('>').unwrap_or(inner_type), - ) + let is_unit_type = is_unit_type(&ok_type); + if is_unit_type { + let mut text_range = ret_type.syntax().text_range(); + + if let Some(NodeOrToken::Token(token)) = ret_type.syntax().next_sibling_or_token() { + if token.kind() == SyntaxKind::WHITESPACE { + text_range = TextRange::new(text_range.start(), token.text_range().end()); + } } + + builder.delete(text_range); + } else { + builder.replace(type_ref.syntax().text_range(), ok_type.syntax().text()); } for ret_expr_arg in exprs_to_unwrap { @@ -134,6 +128,22 @@ fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { } } +// Tries to extract `T` from `Result`. +fn unwrap_result_type(ty: &ast::Type) -> Option { + let ast::Type::PathType(path_ty) = ty else { return None; }; + let path = path_ty.path()?; + let segment = path.first_segment()?; + let generic_arg_list = segment.generic_arg_list()?; + let generic_args: Vec<_> = generic_arg_list.generic_args().collect(); + let ast::GenericArg::TypeArg(ok_type) = generic_args.first()? else { return None; }; + ok_type.ty() +} + +fn is_unit_type(ty: &ast::Type) -> bool { + let ast::Type::TupleType(tuple) = ty else { return false }; + tuple.fields().next().is_none() +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -173,6 +183,21 @@ fn foo() -> Result<(), Box> { r#" fn foo() { } +"#, + ); + + // Unformatted return type + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result<(), Box>{ + Ok(()) +} +"#, + r#" +fn foo() { +} "#, ); } @@ -1014,6 +1039,54 @@ fn foo(the_field: u32) -> u32 { } the_field } +"#, + ); + } + + #[test] + fn unwrap_result_return_type_nested_type() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result, option +fn foo() -> Result, ()> { + Ok(Some(42)) +} +"#, + r#" +fn foo() -> Option { + Some(42) +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result, option +fn foo() -> Result>, ()> { + Ok(None) +} +"#, + r#" +fn foo() -> Option> { + None +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result, option, iterators +fn foo() -> Result$0, ()> { + Ok(Some(42).into_iter()) +} +"#, + r#" +fn foo() -> impl Iterator { + Some(42).into_iter() +} "#, ); }