diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 1a6cfebed41..09c2a9bc755 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -68,6 +68,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option if body.is_none() && node.kind() == BLOCK_EXPR { body = FunctionBody::from_range(&node, ctx.frange.range); } + if let Some(parent) = node.parent() { + if body.is_none() && parent.kind() == BLOCK_EXPR { + body = FunctionBody::from_range(&parent, ctx.frange.range); + } + } if body.is_none() { body = FunctionBody::from_whole_node(node.clone()); } @@ -76,10 +81,47 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option } let body = body?; - let insert_after = body.scope_for_fn_insertion()?; + let mut self_param = None; + let mut param_pats: Vec<_> = local_variables(&body, &ctx) + .into_iter() + .map(|node| node.source(ctx.db())) + .filter(|src| { + src.file_id.original_file(ctx.db()) == ctx.frange.file_id + && !body.contains_node(&either_syntax(&src.value)) + }) + .filter_map(|src| match src.value { + Either::Left(pat) => Some(pat), + Either::Right(it) => { + // we filter self param, as there can only be one + self_param = Some(it); + None + } + }) + .collect(); + deduplicate_params(&mut param_pats); + let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; + let insert_after = body.scope_for_fn_insertion(anchor)?; let module = ctx.sema.scope(&insert_after).module()?; + let params = param_pats + .into_iter() + .map(|pat| { + let ty = pat + .pat() + .and_then(|pat| ctx.sema.type_of_pat(&pat)) + .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) + .unwrap_or_else(|| "()".to_string()); + + let name = pat.name().unwrap().to_string(); + + Param { name, ty } + }) + .collect::>(); + + let self_param = + if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None }; + let expr = body.tail_expr(); let ret_ty = match expr { Some(expr) => { @@ -96,36 +138,12 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option FunctionBody::Span { .. } => ctx.frange.range, }; - let mut params = local_variables(&body, &ctx) - .into_iter() - .map(|node| node.source(ctx.db())) - .filter(|src| src.file_id.original_file(ctx.db()) == ctx.frange.file_id) - .map(|src| match src.value { - Either::Left(pat) => { - (pat.syntax().clone(), pat.name(), ctx.sema.type_of_pat(&pat.into())) - } - Either::Right(it) => (it.syntax().clone(), it.name(), ctx.sema.type_of_self(&it)), - }) - .filter(|(node, _, _)| !body.contains_node(node)) - .map(|(_, name, ty)| { - let ty = ty - .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) - .unwrap_or_else(|| "()".to_string()); - - let name = name.unwrap().to_string(); - - Param { name, ty } - }) - .collect::>(); - deduplicate_params(&mut params); - acc.add( AssistId("extract_function", crate::AssistKind::RefactorExtract), "Extract into function", target_range, move |builder| { - - let fun = Function { name: "fun_name".to_string(), params, ret_ty, body }; + let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; builder.replace(target_range, format_replacement(&fun)); @@ -140,6 +158,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option fn format_replacement(fun: &Function) -> String { let mut buf = String::new(); + if fun.self_param.is_some() { + format_to!(buf, "self."); + } format_to!(buf, "{}(", fun.name); { let mut it = fun.params.iter(); @@ -161,6 +182,7 @@ fn format_replacement(fun: &Function) -> String { struct Function { name: String, + self_param: Option, params: Vec, ret_ty: Option, body: FunctionBody, @@ -186,7 +208,9 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); { let mut it = fun.params.iter(); - if let Some(param) = it.next() { + if let Some(self_param) = &fun.self_param { + format_to!(fn_def, "{}", self_param); + } else if let Some(param) = it.next() { format_to!(fn_def, "{}: {}", param.name, param.ty); } for param in it { @@ -230,6 +254,11 @@ enum FunctionBody { Span { elements: Vec, leading_indent: String }, } +enum Anchor { + Freestanding, + Method, +} + impl FunctionBody { fn from_whole_node(node: SyntaxNode) -> Option { match node.kind() { @@ -288,12 +317,12 @@ fn tail_expr(&self) -> Option { } } - fn scope_for_fn_insertion(&self) -> Option { + fn scope_for_fn_insertion(&self, anchor: Anchor) -> Option { match self { - FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax()), + FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax(), anchor), FunctionBody::Span { elements, .. } => { let node = elements.iter().find_map(|e| e.as_node())?; - scope_for_fn_insertion(&node) + scope_for_fn_insertion(&node, anchor) } } } @@ -325,14 +354,25 @@ fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { } } -fn scope_for_fn_insertion(node: &SyntaxNode) -> Option { +fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option { let mut ancestors = node.ancestors().peekable(); let mut last_ancestor = None; while let Some(next_ancestor) = ancestors.next() { match next_ancestor.kind() { SyntaxKind::SOURCE_FILE => break, SyntaxKind::ITEM_LIST => { - if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { + if !matches!(anchor, Anchor::Freestanding) { + continue; + } + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) { + break; + } + } + SyntaxKind::ASSOC_ITEM_LIST => { + if !matches!(anchor, Anchor::Method) { + continue; + } + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) { break; } } @@ -343,15 +383,21 @@ fn scope_for_fn_insertion(node: &SyntaxNode) -> Option { last_ancestor } -fn deduplicate_params(params: &mut Vec) { +fn deduplicate_params(params: &mut Vec) { let mut seen_params = FxHashSet::default(); - params.retain(|p| seen_params.insert(p.name.clone())); + params.retain(|p| seen_params.insert(p.clone())); +} + +fn either_syntax(value: &Either) -> &SyntaxNode { + match value { + Either::Left(pat) => pat.syntax(), + Either::Right(it) => it.syntax(), + } } /// Returns a vector of local variables that are refferenced in `body` fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body - .descendants() + body.descendants() .filter_map(ast::NameRef::cast) .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) .map(|name_kind| name_kind.referenced(ctx.db())) @@ -386,7 +432,7 @@ fn $0fun_name() -> i32 { }"#, ); } - + #[test] fn no_args_from_binary_expr_in_module() { check_assist( @@ -816,4 +862,112 @@ fn $0fun_name() -> u32 { fn return_not_applicable() { check_assist_not_applicable(extract_function, r"fn foo() { $0return$0; } "); } + + #[test] + fn method_to_freestanding() { + check_assist( + extract_function, + r" +struct S; + +impl S { + fn foo(&self) -> i32 { + $01+1$0 + } +}", + r" +struct S; + +impl S { + fn foo(&self) -> i32 { + fun_name() + } +} + +fn $0fun_name() -> i32 { + 1+1 +}", + ); + } + + #[test] + fn method_with_reference() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + $01+self.f$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + self.fun_name() + } + + fn $0fun_name(&self) -> i32 { + 1+self.f + } +}", + ); + } + + #[test] + fn method_with_mut() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + $0self.f += 1;$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + self.fun_name(); + } + + fn $0fun_name(&mut self) { + self.f += 1; + } +}", + ); + } + + #[test] + fn method_with_mut_downgrade_to_shared() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) -> i32 { + $01+self.f$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) -> i32 { + self.fun_name() + } + + fn $0fun_name(&self) -> i32 { + 1+self.f + } +}", + ); + } }