diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 68dcb6c0d23..c9f01ba64a5 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -2,7 +2,7 @@ use ast::make; use either::Either; -use hir::{HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo}; +use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo}; use ide_db::{ defs::{Definition, NameRefClass}, famous_defs::FamousDefs, @@ -27,6 +27,7 @@ use crate::{ assist_context::{AssistContext, Assists, TreeMutator}, + utils::generate_impl_text, AssistId, }; @@ -106,6 +107,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let params = body.extracted_function_params(ctx, &container_info, locals_used.iter().copied()); + let extracted_from_trait_impl = body.extracted_from_trait_impl(); + let name = make_function_name(&semantics_scope); let fun = Function { @@ -124,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option builder.replace(target_range, make_call(ctx, &fun, old_indent)); - let fn_def = format_function(ctx, module, &fun, old_indent, new_indent); - let insert_offset = insert_after.text_range().end(); + let fn_def = match fun.self_param_adt(ctx) { + Some(adt) if extracted_from_trait_impl => { + let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1); + generate_impl_text(&adt, &fn_def).replace("{\n\n", "{") + } + _ => format_function(ctx, module, &fun, old_indent, new_indent), + }; if fn_def.contains("ControlFlow") { let scope = match scope { @@ -150,6 +158,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option } } + let insert_offset = insert_after.text_range().end(); + match ctx.config.snippet_cap { Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def), None => builder.insert(insert_offset, fn_def), @@ -381,6 +391,14 @@ fn return_type(&self, ctx: &AssistContext) -> FunType { }, } } + + fn self_param_adt(&self, ctx: &AssistContext) -> Option { + let self_param = self.self_param.as_ref()?; + let def = ctx.sema.to_def(self_param)?; + let adt = def.ty(ctx.db()).strip_references().as_adt()?; + let InFile { file_id: _, value } = adt.source(ctx.db())?; + Some(value) + } } impl ParamKind { @@ -485,6 +503,20 @@ fn parent(&self) -> Option { } } + fn node(&self) -> &SyntaxNode { + match self { + FunctionBody::Expr(e) => e.syntax(), + FunctionBody::Span { parent, .. } => parent.syntax(), + } + } + + fn extracted_from_trait_impl(&self) -> bool { + match self.node().ancestors().find_map(ast::Impl::cast) { + Some(c) => return c.trait_().is_some(), + None => false, + } + } + fn from_expr(expr: ast::Expr) -> Option { match expr { ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr), @@ -1111,10 +1143,7 @@ fn either_syntax(value: &Either) -> &SyntaxNode { /// /// Function should be put right after returned node fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option { - let node = match body { - FunctionBody::Expr(e) => e.syntax(), - FunctionBody::Span { parent, .. } => parent.syntax(), - }; + let node = body.node(); let mut ancestors = node.ancestors().peekable(); let mut last_ancestor = None; while let Some(next_ancestor) = ancestors.next() { @@ -1126,9 +1155,8 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option { - continue; - } + SyntaxKind::ASSOC_ITEM_LIST if !matches!(anchor, Anchor::Method) => continue, + SyntaxKind::ASSOC_ITEM_LIST if body.extracted_from_trait_impl() => continue, SyntaxKind::ASSOC_ITEM_LIST => { if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) { break; @@ -4777,6 +4805,43 @@ fn fun_name() { fn $0fun_name2() { let x = 0; } +"#, + ); + } + + #[test] + fn extract_method_from_trait_impl() { + check_assist( + extract_function, + r#" +struct Struct(i32); +trait Trait { + fn bar(&self) -> i32; +} + +impl Trait for Struct { + fn bar(&self) -> i32 { + $0self.0 + 2$0 + } +} +"#, + r#" +struct Struct(i32); +trait Trait { + fn bar(&self) -> i32; +} + +impl Trait for Struct { + fn bar(&self) -> i32 { + self.fun_name() + } +} + +impl Struct { + fn $0fun_name(&self) -> i32 { + self.0 + 2 + } +} "#, ); }