diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index e9db38aca0f..80a60b4a667 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -1385,31 +1385,30 @@ enum FlowHandler { impl FlowHandler { fn from_ret_ty(fun: &Function, ret_ty: &FunType) -> FlowHandler { - match &fun.control_flow.kind { - None => FlowHandler::None, - Some(flow_kind) => { - let action = flow_kind.clone(); - if let FunType::Unit = ret_ty { - match flow_kind { - FlowKind::Return(None) - | FlowKind::Break(_, None) - | FlowKind::Continue(_) => FlowHandler::If { action }, - FlowKind::Return(_) | FlowKind::Break(_, _) => { - FlowHandler::IfOption { action } - } - FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, - } - } else { - match flow_kind { - FlowKind::Return(None) - | FlowKind::Break(_, None) - | FlowKind::Continue(_) => FlowHandler::MatchOption { none: action }, - FlowKind::Return(_) | FlowKind::Break(_, _) => { - FlowHandler::MatchResult { err: action } - } - FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, - } + if fun.contains_tail_expr { + return FlowHandler::None; + } + let Some(action) = fun.control_flow.kind.clone() else { + return FlowHandler::None; + }; + + if let FunType::Unit = ret_ty { + match action { + FlowKind::Return(None) | FlowKind::Break(_, None) | FlowKind::Continue(_) => { + FlowHandler::If { action } } + FlowKind::Return(_) | FlowKind::Break(_, _) => FlowHandler::IfOption { action }, + FlowKind::Try { kind } => FlowHandler::Try { kind }, + } + } else { + match action { + FlowKind::Return(None) | FlowKind::Break(_, None) | FlowKind::Continue(_) => { + FlowHandler::MatchOption { none: action } + } + FlowKind::Return(_) | FlowKind::Break(_, _) => { + FlowHandler::MatchResult { err: action } + } + FlowKind::Try { kind } => FlowHandler::Try { kind }, } } } @@ -1654,11 +1653,7 @@ fn make_param_list(&self, ctx: &AssistContext<'_>, module: hir::Module) -> ast:: fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option { let fun_ty = self.return_type(ctx); - let handler = if self.contains_tail_expr { - FlowHandler::None - } else { - FlowHandler::from_ret_ty(self, &fun_ty) - }; + let handler = FlowHandler::from_ret_ty(self, &fun_ty); let ret_ty = match &handler { FlowHandler::None => { if matches!(fun_ty, FunType::Unit) { @@ -1728,11 +1723,7 @@ fn make_body( fun: &Function, ) -> ast::BlockExpr { let ret_ty = fun.return_type(ctx); - let handler = if fun.contains_tail_expr { - FlowHandler::None - } else { - FlowHandler::from_ret_ty(fun, &ret_ty) - }; + let handler = FlowHandler::from_ret_ty(fun, &ret_ty); let block = match &fun.body { FunctionBody::Expr(expr) => { @@ -4471,7 +4462,7 @@ async fn foo() -> Result<(), ()> { "#, r#" async fn foo() -> Result<(), ()> { - fun_name().await? + fun_name().await } async fn $0fun_name() -> Result<(), ()> { @@ -4690,7 +4681,7 @@ fn extract_does_not_wrap_res_in_res() { check_assist( extract_function, r#" -//- minicore: result +//- minicore: result, try fn foo() -> Result<(), i64> { $0Result::::Ok(0)?; Ok(())$0 @@ -4698,7 +4689,7 @@ fn foo() -> Result<(), i64> { "#, r#" fn foo() -> Result<(), i64> { - fun_name()? + fun_name() } fn $0fun_name() -> Result<(), i64> { @@ -5753,6 +5744,34 @@ fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { ); } + #[test] + fn tail_expr_no_extra_control_flow() { + check_assist( + extract_function, + r#" +//- minicore: result +fn fallible() -> Result<(), ()> { + $0if true { + return Err(()); + } + Ok(())$0 +} +"#, + r#" +fn fallible() -> Result<(), ()> { + fun_name() +} + +fn $0fun_name() -> Result<(), ()> { + if true { + return Err(()); + } + Ok(()) +} +"#, + ); + } + #[test] fn non_tail_expr_of_tail_expr_loop() { check_assist( @@ -5800,12 +5819,6 @@ fn non_tail_expr_of_tail_if_block() { extract_function, r#" //- minicore: option, try -impl core::ops::Try for Option { - type Output = T; - type Residual = Option; -} -impl core::ops::FromResidual for Option {} - fn f() -> Option<()> { if true { let a = $0if true { @@ -5820,12 +5833,6 @@ fn f() -> Option<()> { } "#, r#" -impl core::ops::Try for Option { - type Output = T; - type Residual = Option; -} -impl core::ops::FromResidual for Option {} - fn f() -> Option<()> { if true { let a = fun_name()?;; @@ -5852,12 +5859,6 @@ fn tail_expr_of_tail_block_nested() { extract_function, r#" //- minicore: option, try -impl core::ops::Try for Option { - type Output = T; - type Residual = Option; -} -impl core::ops::FromResidual for Option {} - fn f() -> Option<()> { if true { $0{ @@ -5874,15 +5875,9 @@ fn f() -> Option<()> { } "#, r#" -impl core::ops::Try for Option { - type Output = T; - type Residual = Option; -} -impl core::ops::FromResidual for Option {} - fn f() -> Option<()> { if true { - fun_name()? + fun_name() } else { None }