diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index c5e6ec7331b..ffa8bd77dc6 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -2,19 +2,20 @@ use hir::{HirDisplay, Local}; use ide_db::{ defs::{Definition, NameRefClass}, - search::SearchScope, + search::{ReferenceAccess, SearchScope}, }; use itertools::Itertools; use stdx::format_to; use syntax::{ + algo::SyntaxRewriter, ast::{ self, edit::{AstNodeEdit, IndentLevel}, - AstNode, NameOwner, + AstNode, }, Direction, SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, - SyntaxNode, TextRange, + SyntaxNode, TextRange, T, }; use test_utils::mark; @@ -88,16 +89,16 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let mut self_param = None; let param_pats: Vec<_> = vars_used_in_body .iter() - .map(|node| node.source(ctx.db())) - .filter(|src| { + .map(|node| (node, node.source(ctx.db()))) + .filter(|(_, src)| { src.file_id.original_file(ctx.db()) == ctx.frange.file_id && !body.contains_node(&either_syntax(&src.value)) }) - .filter_map(|src| match src.value { - Either::Left(pat) => Some(pat), + .filter_map(|(&node, src)| match src.value { + Either::Left(_) => Some(node), Either::Right(it) => { // we filter self param, as there can only be one - self_param = Some(it); + self_param = Some((node, it)); None } }) @@ -109,7 +110,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let vars_defined_in_body = vars_defined_in_body(&body, ctx); - let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body + let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body .iter() .copied() .filter(|node| { @@ -123,20 +124,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option }) .collect(); - let params = param_pats + let params: Vec<_> = param_pats .into_iter() - .map(|pat| { - let name = pat.name().unwrap().to_string(); + .map(|node| { + let usages = Definition::Local(node) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); - let ty = ctx - .sema - .type_of_pat(&pat.into()) - .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) - .unwrap_or_else(|| "()".to_string()); + let has_usages_afterwards = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .any(|reference| body.preceedes_range(reference.range)); + let has_mut_inside_body = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .filter(|reference| body.contains_range(reference.range)) + .any(|reference| reference.access == Some(ReferenceAccess::Write)); - Param { name, ty } + Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } }) - .collect::>(); + .collect(); let expr = body.tail_expr(); let ret_ty = match expr { @@ -145,7 +153,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option }; let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); - if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) { + if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) { // We should not have variables that outlive body if we have expression block return None; } @@ -162,11 +170,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option move |builder| { let fun = Function { name: "fun_name".to_string(), - self_param, + self_param: self_param.map(|(_, pat)| pat), params, ret_ty, body, - vars_in_body_used_afterwards, + vars_defined_in_body_and_outlive, }; builder.replace(target_range, format_replacement(ctx, &fun)); @@ -183,17 +191,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { let mut buf = String::new(); - match fun.vars_in_body_used_afterwards.len() { - 0 => {} - 1 => format_to!( - buf, - "let {} = ", - fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap() - ), - _ => { + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { buf.push_str("let ("); - format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()); - for local in fun.vars_in_body_used_afterwards.iter().skip(1) { + format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); + for local in vs { format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); } buf.push_str(") = "); @@ -207,10 +211,10 @@ fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { { let mut it = fun.params.iter(); if let Some(param) = it.next() { - format_to!(buf, "{}", param.name); + format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); } for param in it { - format_to!(buf, ", {}", param.name); + format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); } } format_to!(buf, ")"); @@ -228,7 +232,7 @@ struct Function { params: Vec, ret_ty: Option, body: FunctionBody, - vars_in_body_used_afterwards: Vec, + vars_defined_in_body_and_outlive: Vec, } impl Function { @@ -242,8 +246,60 @@ fn has_unit_ret(&self) -> bool { #[derive(Debug)] struct Param { - name: String, - ty: String, + node: Local, + has_usages_afterwards: bool, + has_mut_inside_body: bool, + is_copy: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParamKind { + Value, + MutValue, + SharedRef, + MutRef, +} + +impl ParamKind { + fn is_ref(&self) -> bool { + matches!(self, ParamKind::SharedRef | ParamKind::MutRef) + } +} + +impl Param { + fn kind(&self) -> ParamKind { + match (self.has_usages_afterwards, self.has_mut_inside_body, self.is_copy) { + (true, true, _) => ParamKind::MutRef, + (true, false, false) => ParamKind::SharedRef, + (false, true, _) => ParamKind::MutValue, + (true, false, true) | (false, false, _) => ParamKind::Value, + } + } + + fn value_prefix(&self) -> &'static str { + match self.kind() { + ParamKind::Value => "", + ParamKind::MutValue => "", + ParamKind::SharedRef => "&", + ParamKind::MutRef => "&mut ", + } + } + + fn type_prefix(&self) -> &'static str { + match self.kind() { + ParamKind::Value => "", + ParamKind::MutValue => "", + ParamKind::SharedRef => "&", + ParamKind::MutRef => "&mut ", + } + } + + fn mut_pattern(&self) -> &'static str { + match self.kind() { + ParamKind::MutValue => "mut ", + _ => "", + } + } } fn format_function( @@ -259,10 +315,24 @@ fn format_function( if let Some(self_param) = &fun.self_param { format_to!(fn_def, "{}", self_param); } else if let Some(param) = it.next() { - format_to!(fn_def, "{}: {}", param.name, param.ty); + format_to!( + fn_def, + "{}{}: {}{}", + param.mut_pattern(), + param.node.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.node.ty(ctx.db()), ctx, module) + ); } for param in it { - format_to!(fn_def, ", {}: {}", param.name, param.ty); + format_to!( + fn_def, + ", {}{}: {}{}", + param.mut_pattern(), + param.node.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.node.ty(ctx.db()), ctx, module) + ); } } @@ -272,7 +342,7 @@ fn format_function( format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); } } else { - match fun.vars_in_body_used_afterwards.as_slice() { + match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} [var] => { format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); @@ -292,13 +362,21 @@ fn format_function( FunctionBody::Expr(expr) => { fn_def.push('\n'); let expr = expr.indent(indent); - format_to!(fn_def, "{}{}", indent + 1, expr.syntax()); + let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); + format_to!(fn_def, "{}{}", indent + 1, expr); fn_def.push('\n'); } FunctionBody::Span { elements, leading_indent } => { format_to!(fn_def, "{}", leading_indent); - for e in elements { - format_to!(fn_def, "{}", e); + for element in elements { + match element { + syntax::NodeOrToken::Node(node) => { + format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); + } + syntax::NodeOrToken::Token(token) => { + format_to!(fn_def, "{}", token); + } + } } if !fn_def.ends_with('\n') { fn_def.push('\n'); @@ -306,7 +384,7 @@ fn format_function( } } - match fun.vars_in_body_used_afterwards.as_slice() { + match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), [v0, vs @ ..] => { @@ -327,6 +405,61 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { + let mut rewriter = SyntaxRewriter::default(); + for param in params { + if !param.kind().is_ref() { + continue; + } + + let usages = Definition::Local(param.node) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); + let usages = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .filter(|reference| syntax.text_range().contains_range(reference.range)); + for reference in usages { + let token = match syntax.token_at_offset(reference.range.start()).right_biased() { + Some(a) => a, + None => { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + continue; + } + }; + let path = match token.ancestors().find_map(ast::Expr::cast) { + Some(n) => n, + None => { + stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); + continue; + } + }; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + Some(ast::Expr::MethodCallExpr(_)) => { + // do nothing + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(_) | None => { + rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); + } + }; + } + } + + rewriter.rewrite(syntax) +} + #[derive(Debug)] enum FunctionBody { Expr(ast::Expr), @@ -1112,6 +1245,164 @@ fn $0fun_name(n: i32) -> (i32, i32) { let k = n * n; let m = k + 2; (k, m) +}", + ); + } + + #[test] + fn mut_var_from_outer_scope() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0n += 1;$0 + let m = n + 1; +}", + r" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += 1; +}", + ); + } + + #[test] + fn mut_param_many_usages_stmt() { + check_assist( + extract_function, + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ();$0 + let m = n + 1; +}", + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); +}", + ); + } + + #[test] + fn mut_param_many_usages_expr() { + check_assist( + extract_function, + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0{ + n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ(); + }$0 + let m = n + 1; +}", + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); + } +}", + ); + } + + #[test] + fn mut_param_by_value() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0n += 1;$0 +}", + r" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + n += 1; }", ); } diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index b755c969288..1da5a125ed3 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -487,7 +487,7 @@ pub mod tokens { use crate::{ast, AstNode, Parse, SourceFile, SyntaxKind::*, SyntaxToken}; pub(super) static SOURCE_FILE: Lazy> = - Lazy::new(|| SourceFile::parse("const C: <()>::Item = (1 != 1, 2 == 2, !true)\n;\n\n")); + Lazy::new(|| SourceFile::parse("const C: <()>::Item = (1 != 1, 2 == 2, !true, *p)\n;\n\n")); pub fn single_space() -> SyntaxToken { SOURCE_FILE