diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index dce7ffd7b12..93ff66b246b 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -60,115 +60,21 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option return None; } - let node = match node { - syntax::NodeOrToken::Node(n) => n, - syntax::NodeOrToken::Token(t) => t.parent(), - }; + let node = element_to_node(node); - let mut body = None; - if node.text_range() == ctx.frange.range { - body = FunctionBody::from_whole_node(node.clone()); - } - if body.is_none() && node.kind() == BLOCK_EXPR { - body = FunctionBody::from_range(&node, ctx.frange.range); - } - if let Some(parent) = node.parent() { - if body.is_none() && parent.kind() == BLOCK_EXPR { - body = FunctionBody::from_range(&parent, ctx.frange.range); - } - } - if body.is_none() { - body = FunctionBody::from_whole_node(node.clone()); - } - if body.is_none() { - body = node.ancestors().find_map(FunctionBody::from_whole_node); - } - let body = body?; + let body = extraction_target(&node, ctx.frange.range)?; let vars_used_in_body = vars_used_in_body(&body, &ctx); - let mut self_param = None; - let param_pats: Vec<_> = vars_used_in_body - .iter() - .map(|node| (node, node.source(ctx.db()))) - .filter(|(_, src)| { - src.file_id.original_file(ctx.db()) == ctx.frange.file_id - && !body.contains_node(&either_syntax(&src.value)) - }) - .filter_map(|(&node, src)| match src.value { - Either::Left(_) => Some(node), - Either::Right(it) => { - // we filter self param, as there can only be one - self_param = Some((node, it)); - None - } - }) - .collect(); + let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body); let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; - let insert_after = body.scope_for_fn_insertion(anchor)?; + let insert_after = scope_for_fn_insertion(&body, anchor)?; let module = ctx.sema.scope(&insert_after).module()?; - let vars_defined_in_body = vars_defined_in_body(&body, ctx); + let vars_defined_in_body_and_outlive = vars_defined_in_body_and_outlive(ctx, &body); + let ret_ty = body_return_ty(ctx, &body)?; - let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body - .iter() - .copied() - .filter(|node| { - let usages = Definition::Local(*node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); - - usages.any(|reference| body.preceedes_range(reference.range)) - }) - .collect(); - - let params: Vec<_> = param_pats - .into_iter() - .map(|node| { - let usages = Definition::Local(node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - - let has_usages_afterwards = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .any(|reference| body.preceedes_range(reference.range)); - let has_mut_inside_body = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .filter(|reference| body.contains_range(reference.range)) - .any(|reference| { - if reference.access == Some(ReferenceAccess::Write) { - return true; - } - - let path = path_at_offset(&body, reference); - if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { - return true; - } - - if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { - return true; - } - - false - }); - - Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } - }) - .collect(); - - let expr = body.tail_expr(); - let ret_ty = match expr { - Some(expr) => Some(ctx.sema.type_of_expr(&expr)?), - None => None, - }; - - let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); - if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) { + if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !ret_ty.is_unit()) { // We should not have variables that outlive body if we have expression block return None; } @@ -183,6 +89,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option "Extract into function", target_range, move |builder| { + let params = extracted_function_params(ctx, &body, &vars_used_in_body); + let fun = Function { name: "fun_name".to_string(), self_param: self_param.map(|(_, pat)| pat), @@ -203,65 +111,19 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option ) } -fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { - let mut buf = String::new(); - - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), - [v0, vs @ ..] => { - buf.push_str("let ("); - format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); - for local in vs { - format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); - } - buf.push_str(") = "); - } - } - - if fun.self_param.is_some() { - format_to!(buf, "self."); - } - format_to!(buf, "{}(", fun.name); - { - let mut it = fun.params.iter(); - if let Some(param) = it.next() { - format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); - } - for param in it { - format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); - } - } - format_to!(buf, ")"); - - if fun.has_unit_ret() { - format_to!(buf, ";"); - } - - buf -} - +#[derive(Debug)] struct Function { name: String, self_param: Option, params: Vec, - ret_ty: Option, + ret_ty: RetType, body: FunctionBody, vars_defined_in_body_and_outlive: Vec, } -impl Function { - fn has_unit_ret(&self) -> bool { - match &self.ret_ty { - Some(ty) => ty.is_unit(), - None => true, - } - } -} - #[derive(Debug)] struct Param { - node: Local, + var: Local, has_usages_afterwards: bool, has_mut_inside_body: bool, is_copy: bool, @@ -293,8 +155,7 @@ fn kind(&self) -> ParamKind { fn value_prefix(&self) -> &'static str { match self.kind() { - ParamKind::Value => "", - ParamKind::MutValue => "", + ParamKind::Value | ParamKind::MutValue => "", ParamKind::SharedRef => "&", ParamKind::MutRef => "&mut ", } @@ -302,8 +163,7 @@ fn value_prefix(&self) -> &'static str { fn type_prefix(&self) -> &'static str { match self.kind() { - ParamKind::Value => "", - ParamKind::MutValue => "", + ParamKind::Value | ParamKind::MutValue => "", ParamKind::SharedRef => "&", ParamKind::MutRef => "&mut ", } @@ -317,186 +177,27 @@ fn mut_pattern(&self) -> &'static str { } } -fn format_function( - ctx: &AssistContext, - module: hir::Module, - fun: &Function, - indent: IndentLevel, -) -> String { - let mut fn_def = String::new(); - format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); - { - let mut it = fun.params.iter(); - if let Some(self_param) = &fun.self_param { - format_to!(fn_def, "{}", self_param); - } else if let Some(param) = it.next() { - format_to!( - fn_def, - "{}{}: {}{}", - param.mut_pattern(), - param.node.name(ctx.db()).unwrap(), - param.type_prefix(), - format_type(¶m.node.ty(ctx.db()), ctx, module) - ); - } - for param in it { - format_to!( - fn_def, - ", {}{}: {}{}", - param.mut_pattern(), - param.node.name(ctx.db()).unwrap(), - param.type_prefix(), - format_type(¶m.node.ty(ctx.db()), ctx, module) - ); +#[derive(Debug)] +enum RetType { + Expr(hir::Type), + Stmt, +} + +impl RetType { + fn is_unit(&self) -> bool { + match self { + RetType::Expr(ty) => ty.is_unit(), + RetType::Stmt => true, } } - format_to!(fn_def, ")"); - if !fun.has_unit_ret() { - if let Some(ty) = &fun.ret_ty { - format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); - } - } else { - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => { - format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); - } - [v0, vs @ ..] => { - format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); - for var in vs { - format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); - } - fn_def.push(')'); - } + fn as_fn_ret(&self) -> Option<&hir::Type> { + match self { + RetType::Stmt => None, + RetType::Expr(ty) if ty.is_unit() => None, + RetType::Expr(ty) => Some(ty), } } - fn_def.push_str(" {"); - - match &fun.body { - FunctionBody::Expr(expr) => { - fn_def.push('\n'); - let expr = expr.indent(indent); - let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); - format_to!(fn_def, "{}{}", indent + 1, expr); - fn_def.push('\n'); - } - FunctionBody::Span { elements, leading_indent } => { - format_to!(fn_def, "{}", leading_indent); - for element in elements { - match element { - syntax::NodeOrToken::Node(node) => { - format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); - } - syntax::NodeOrToken::Token(token) => { - format_to!(fn_def, "{}", token); - } - } - } - if !fn_def.ends_with('\n') { - fn_def.push('\n'); - } - } - } - - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), - [v0, vs @ ..] => { - format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); - for var in vs { - format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); - } - fn_def.push_str(")\n"); - } - } - - format_to!(fn_def, "{}}}", indent); - - fn_def -} - -fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { - ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) -} - -fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option { - let var = body.token_at_offset(reference.range.start()).right_biased()?; - let path = var.ancestors().find_map(ast::Expr::cast)?; - stdx::always!(matches!(path, ast::Expr::PathExpr(_))); - Some(path) -} - -fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { - let path = path?; - let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; - Some(ref_expr.mut_token().is_some()) -} - -fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { - let path = path?; - let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; - - let func = ctx.sema.resolve_method_call(&method_call)?; - let self_param = func.self_param(ctx.db())?; - let access = self_param.access(ctx.db()); - - Some(matches!(access, hir::Access::Exclusive)) -} - -fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { - let mut rewriter = SyntaxRewriter::default(); - for param in params { - if !param.kind().is_ref() { - continue; - } - - let usages = Definition::Local(param.node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - let usages = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .filter(|reference| syntax.text_range().contains_range(reference.range)); - for reference in usages { - let token = match syntax.token_at_offset(reference.range.start()).right_biased() { - Some(a) => a, - None => { - stdx::never!(false, "cannot find token at variable usage: {:?}", reference); - continue; - } - }; - let path = match token.ancestors().find_map(ast::Expr::cast) { - Some(n) => n, - None => { - stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); - continue; - } - }; - stdx::always!(matches!(path, ast::Expr::PathExpr(_))); - match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { - Some(ast::Expr::MethodCallExpr(_)) => { - // do nothing - } - Some(ast::Expr::RefExpr(node)) - if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => - { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); - } - Some(ast::Expr::RefExpr(node)) - if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => - { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); - } - Some(_) | None => { - rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); - } - }; - } - } - - rewriter.rewrite(syntax) } #[derive(Debug)] @@ -505,11 +206,6 @@ enum FunctionBody { Span { elements: Vec, leading_indent: String }, } -enum Anchor { - Freestanding, - Method, -} - impl FunctionBody { fn from_whole_node(node: SyntaxNode) -> Option { match node.kind() { @@ -568,16 +264,6 @@ fn tail_expr(&self) -> Option { } } - fn scope_for_fn_insertion(&self, anchor: Anchor) -> Option { - match self { - FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax(), anchor), - FunctionBody::Span { elements, .. } => { - let node = elements.iter().find_map(|e| e.as_node())?; - scope_for_fn_insertion(&node, anchor) - } - } - } - fn descendants(&self) -> impl Iterator + '_ { match self { FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), @@ -590,6 +276,30 @@ fn descendants(&self) -> impl Iterator + '_ { } } + fn text_range(&self) -> TextRange { + match self { + FunctionBody::Expr(expr) => expr.syntax().text_range(), + FunctionBody::Span { elements, .. } => TextRange::new( + elements.first().unwrap().text_range().start(), + elements.last().unwrap().text_range().end(), + ), + } + } + + fn contains_range(&self, range: TextRange) -> bool { + self.text_range().contains_range(range) + } + + fn preceedes_range(&self, range: TextRange) -> bool { + self.text_range().end() <= range.start() + } + + fn contains_node(&self, node: &SyntaxNode) -> bool { + self.contains_range(node.text_range()) + } +} + +impl HasTokenAtOffset for FunctionBody { fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { match self { FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), @@ -621,31 +331,278 @@ fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { } } } +} - fn text_range(&self) -> TextRange { - match self { - FunctionBody::Expr(expr) => expr.syntax().text_range(), - FunctionBody::Span { elements, .. } => TextRange::new( - elements.first().unwrap().text_range().start(), - elements.last().unwrap().text_range().end(), - ), - } - } - - fn contains_range(&self, range: TextRange) -> bool { - self.text_range().contains_range(range) - } - - fn preceedes_range(&self, range: TextRange) -> bool { - self.text_range().end() <= range.start() - } - - fn contains_node(&self, node: &SyntaxNode) -> bool { - self.contains_range(node.text_range()) +fn element_to_node(node: SyntaxElement) -> SyntaxNode { + match node { + syntax::NodeOrToken::Node(n) => n, + syntax::NodeOrToken::Token(t) => t.parent(), } } -fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option { +fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option { + if node.text_range() == selection_range { + let body = FunctionBody::from_whole_node(node.clone()); + if body.is_some() { + return body; + } + } + + if node.kind() == BLOCK_EXPR { + let body = FunctionBody::from_range(&node, selection_range); + if body.is_some() { + return body; + } + } + if let Some(parent) = node.parent() { + if parent.kind() == BLOCK_EXPR { + let body = FunctionBody::from_range(&parent, selection_range); + if body.is_some() { + return body; + } + } + } + + let body = FunctionBody::from_whole_node(node.clone()); + if body.is_some() { + return body; + } + + let body = node.ancestors().find_map(FunctionBody::from_whole_node); + if body.is_some() { + return body; + } + + None +} + +/// Returns a vector of local variables that are referenced in `body` +fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body.descendants() + .filter_map(ast::NameRef::cast) + .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) + .map(|name_kind| name_kind.referenced(ctx.db())) + .filter_map(|definition| match definition { + Definition::Local(local) => Some(local), + _ => None, + }) + .unique() + .collect() +} + +fn self_param_from_usages( + ctx: &AssistContext, + body: &FunctionBody, + vars_used_in_body: &[Local], +) -> Option<(Local, ast::SelfParam)> { + let mut iter = vars_used_in_body + .iter() + .filter(|var| var.is_self(ctx.db())) + .map(|var| (var, var.source(ctx.db()))) + .filter(|(_, src)| is_defined_before(ctx, body, src)) + .filter_map(|(&node, src)| match src.value { + Either::Right(it) => Some((node, it)), + Either::Left(_) => { + stdx::never!(false, "Local::is_self returned true, but source is IdentPat"); + None + } + }); + + let self_param = iter.next(); + stdx::always!( + iter.next().is_none(), + "body references two different self params both defined outside" + ); + + self_param +} + +fn extracted_function_params( + ctx: &AssistContext, + body: &FunctionBody, + vars_used_in_body: &[Local], +) -> Vec { + vars_used_in_body + .iter() + .filter(|var| !var.is_self(ctx.db())) + .map(|node| (node, node.source(ctx.db()))) + .filter(|(_, src)| is_defined_before(ctx, body, src)) + .filter_map(|(&node, src)| { + if src.value.is_left() { + Some(node) + } else { + stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); + None + } + }) + .map(|var| { + let usages = LocalUsages::find(ctx, var); + Param { + var, + has_usages_afterwards: has_usages_after_body(&usages, body), + has_mut_inside_body: has_exclusive_usages(ctx, &usages, body), + is_copy: true, + } + }) + .collect() +} + +fn has_usages_after_body(usages: &LocalUsages, body: &FunctionBody) -> bool { + usages.iter().any(|reference| body.preceedes_range(reference.range)) +} + +fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool { + usages + .iter() + .filter(|reference| body.contains_range(reference.range)) + .any(|reference| reference_is_exclusive(reference, body, ctx)) +} + +fn reference_is_exclusive( + reference: &FileReference, + body: &FunctionBody, + ctx: &AssistContext, +) -> bool { + if reference.access == Some(ReferenceAccess::Write) { + return true; + } + + let path = path_at_offset(body, reference); + if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { + return true; + } + + if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { + return true; + } + + false +} + +struct LocalUsages(ide_db::search::UsageSearchResult); + +impl LocalUsages { + fn find(ctx: &AssistContext, var: Local) -> Self { + Self( + Definition::Local(var) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(), + ) + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().flat_map(|(_, rs)| rs.iter()) + } +} + +trait HasTokenAtOffset { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset; +} + +impl HasTokenAtOffset for SyntaxNode { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + SyntaxNode::token_at_offset(&self, offset) + } +} + +fn path_at_offset(node: &dyn HasTokenAtOffset, reference: &FileReference) -> Option { + let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + None + })?; + let path = token.ancestors().find_map(ast::Expr::cast).or_else(|| { + stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); + None + })?; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + Some(path) +} + +fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { + let path = path?; + let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; + Some(ref_expr.mut_token().is_some()) +} + +fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { + let path = path?; + let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; + + let func = ctx.sema.resolve_method_call(&method_call)?; + let self_param = func.self_param(ctx.db())?; + let access = self_param.access(ctx.db()); + + Some(matches!(access, hir::Access::Exclusive)) +} + +/// Returns a vector of local variables that are defined in `body` +fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body.descendants() + .filter_map(ast::IdentPat::cast) + .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) + .unique() + .collect() +} + +fn vars_defined_in_body_and_outlive(ctx: &AssistContext, body: &FunctionBody) -> Vec { + let mut vars_defined_in_body = vars_defined_in_body(&body, ctx); + vars_defined_in_body.retain(|var| var_outlives_body(ctx, body, var)); + vars_defined_in_body +} + +fn is_defined_before( + ctx: &AssistContext, + body: &FunctionBody, + src: &hir::InFile>, +) -> bool { + src.file_id.original_file(ctx.db()) == ctx.frange.file_id + && !body.contains_node(&either_syntax(&src.value)) +} + +fn either_syntax(value: &Either) -> &SyntaxNode { + match value { + Either::Left(pat) => pat.syntax(), + Either::Right(it) => it.syntax(), + } +} + +fn var_outlives_body(ctx: &AssistContext, body: &FunctionBody, var: &Local) -> bool { + let usages = Definition::Local(*var) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); + let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); + + usages.any(|reference| body.preceedes_range(reference.range)) +} + +fn body_return_ty(ctx: &AssistContext, body: &FunctionBody) -> Option { + match body.tail_expr() { + Some(expr) => { + let ty = ctx.sema.type_of_expr(&expr)?; + Some(RetType::Expr(ty)) + } + None => Some(RetType::Stmt), + } +} +#[derive(Debug)] +enum Anchor { + Freestanding, + Method, +} + +fn scope_for_fn_insertion(body: &FunctionBody, anchor: Anchor) -> Option { + match body { + FunctionBody::Expr(e) => scope_for_fn_insertion_node(e.syntax(), anchor), + FunctionBody::Span { elements, .. } => { + let node = elements.iter().find_map(|e| e.as_node())?; + scope_for_fn_insertion_node(&node, anchor) + } + } +} + +fn scope_for_fn_insertion_node(node: &SyntaxNode, anchor: Anchor) -> Option { let mut ancestors = node.ancestors().peekable(); let mut last_ancestor = None; while let Some(next_ancestor) = ancestors.next() { @@ -674,34 +631,207 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option) -> &SyntaxNode { - match value { - Either::Left(pat) => pat.syntax(), - Either::Right(it) => it.syntax(), +fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { + let mut buf = String::new(); + + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { + buf.push_str("let ("); + format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); + for var in vs { + format_to!(buf, ", {}", var.name(ctx.db()).unwrap()); + } + buf.push_str(") = "); + } + } + + if fun.self_param.is_some() { + format_to!(buf, "self."); + } + format_to!(buf, "{}(", fun.name); + format_arg_list_to(&mut buf, fun, ctx); + format_to!(buf, ")"); + + if fun.ret_ty.is_unit() { + format_to!(buf, ";"); + } + + buf +} + +fn format_arg_list_to(buf: &mut String, fun: &Function, ctx: &AssistContext) { + let mut it = fun.params.iter(); + if let Some(param) = it.next() { + format_arg_to(buf, ctx, param); + } + for param in it { + buf.push_str(", "); + format_arg_to(buf, ctx, param); } } -/// Returns a vector of local variables that are referenced in `body` -fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body.descendants() - .filter_map(ast::NameRef::cast) - .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) - .map(|name_kind| name_kind.referenced(ctx.db())) - .filter_map(|definition| match definition { - Definition::Local(local) => Some(local), - _ => None, - }) - .unique() - .collect() +fn format_arg_to(buf: &mut String, ctx: &AssistContext, param: &Param) { + format_to!(buf, "{}{}", param.value_prefix(), param.var.name(ctx.db()).unwrap()); } -/// Returns a vector of local variables that are defined in `body` -fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body.descendants() - .filter_map(ast::IdentPat::cast) - .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) - .unique() - .collect() +fn format_function( + ctx: &AssistContext, + module: hir::Module, + fun: &Function, + indent: IndentLevel, +) -> String { + let mut fn_def = String::new(); + format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); + format_function_param_list_to(&mut fn_def, ctx, module, fun); + fn_def.push(')'); + format_function_ret_to(&mut fn_def, ctx, module, fun); + fn_def.push_str(" {"); + format_function_body_to(&mut fn_def, ctx, indent, fun); + format_to!(fn_def, "{}}}", indent); + + fn_def +} + +fn format_function_param_list_to( + fn_def: &mut String, + ctx: &AssistContext, + module: hir::Module, + fun: &Function, +) { + let mut it = fun.params.iter(); + if let Some(self_param) = &fun.self_param { + format_to!(fn_def, "{}", self_param); + } else if let Some(param) = it.next() { + format_param_to(fn_def, ctx, module, param); + } + for param in it { + fn_def.push_str(", "); + format_param_to(fn_def, ctx, module, param); + } +} + +fn format_param_to(fn_def: &mut String, ctx: &AssistContext, module: hir::Module, param: &Param) { + format_to!( + fn_def, + "{}{}: {}{}", + param.mut_pattern(), + param.var.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.var.ty(ctx.db()), ctx, module) + ); +} + +fn format_function_ret_to( + fn_def: &mut String, + ctx: &AssistContext, + module: hir::Module, + fun: &Function, +) { + if let Some(ty) = fun.ret_ty.as_fn_ret() { + format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); + } else { + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => { + format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + [v0, vs @ ..] => { + format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); + for var in vs { + format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + fn_def.push(')'); + } + } + } +} + +fn format_function_body_to( + fn_def: &mut String, + ctx: &AssistContext, + indent: IndentLevel, + fun: &Function, +) { + match &fun.body { + FunctionBody::Expr(expr) => { + fn_def.push('\n'); + let expr = expr.indent(indent); + let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); + format_to!(fn_def, "{}{}", indent + 1, expr); + fn_def.push('\n'); + } + FunctionBody::Span { elements, leading_indent } => { + format_to!(fn_def, "{}", leading_indent); + for element in elements { + match element { + syntax::NodeOrToken::Node(node) => { + format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); + } + syntax::NodeOrToken::Token(token) => { + format_to!(fn_def, "{}", token); + } + } + } + if !fn_def.ends_with('\n') { + fn_def.push('\n'); + } + } + } + + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { + format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); + for var in vs { + format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); + } + fn_def.push_str(")\n"); + } + } +} + +fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { + ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) +} + +fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { + let mut rewriter = SyntaxRewriter::default(); + for param in params { + if !param.kind().is_ref() { + continue; + } + + let usages = LocalUsages::find(ctx, param.var); + let usages = usages + .iter() + .filter(|reference| syntax.text_range().contains_range(reference.range)) + .filter_map(|reference| path_at_offset(syntax, reference)); + for path in usages { + match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + Some(ast::Expr::MethodCallExpr(_)) => { + // do nothing + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(_) | None => { + rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); + } + }; + } + } + + rewriter.rewrite(syntax) } #[cfg(test)]