diff --git a/crates/ide_assists/src/assist_context.rs b/crates/ide_assists/src/assist_context.rs index 8714e4978c5..4b0bba2abce 100644 --- a/crates/ide_assists/src/assist_context.rs +++ b/crates/ide_assists/src/assist_context.rs @@ -185,7 +185,29 @@ pub(crate) struct AssistBuilder { source_change: SourceChange, /// Maps the original, immutable `SyntaxNode` to a `clone_for_update` twin. - mutated_tree: Option<(SyntaxNode, SyntaxNode)>, + mutated_tree: Option, +} + +pub(crate) struct TreeMutator { + immutable: SyntaxNode, + mutable_clone: SyntaxNode, +} + +impl TreeMutator { + pub(crate) fn new(immutable: &SyntaxNode) -> TreeMutator { + let immutable = immutable.ancestors().last().unwrap(); + let mutable_clone = immutable.clone_for_update(); + TreeMutator { immutable, mutable_clone } + } + + pub(crate) fn make_mut(&self, node: &N) -> N { + N::cast(self.make_syntax_mut(node.syntax())).unwrap() + } + + pub(crate) fn make_syntax_mut(&self, node: &SyntaxNode) -> SyntaxNode { + let ptr = SyntaxNodePtr::new(node); + ptr.to_node(&self.mutable_clone) + } } impl AssistBuilder { @@ -204,8 +226,8 @@ pub(crate) fn edit_file(&mut self, file_id: FileId) { } fn commit(&mut self) { - if let Some((old, new)) = self.mutated_tree.take() { - algo::diff(&old, &new).into_text_edit(&mut self.edit) + if let Some(tm) = self.mutated_tree.take() { + algo::diff(&tm.immutable, &tm.mutable_clone).into_text_edit(&mut self.edit) } let edit = mem::take(&mut self.edit).finish(); @@ -228,16 +250,7 @@ pub(crate) fn make_ast_mut(&mut self, node: N) -> N { /// phase, and then get their mutable couterparts using `make_mut` in the /// mutable state. pub(crate) fn make_mut(&mut self, node: SyntaxNode) -> SyntaxNode { - let root = &self - .mutated_tree - .get_or_insert_with(|| { - let immutable = node.ancestors().last().unwrap(); - let mutable = immutable.clone_for_update(); - (immutable, mutable) - }) - .1; - let ptr = SyntaxNodePtr::new(&&node); - ptr.to_node(root) + self.mutated_tree.get_or_insert_with(|| TreeMutator::new(&node)).make_syntax_mut(&node) } /// Remove specified `range` of text. diff --git a/crates/ide_assists/src/handlers/extract_function.rs b/crates/ide_assists/src/handlers/extract_function.rs index b30652a9de6..93b28370cae 100644 --- a/crates/ide_assists/src/handlers/extract_function.rs +++ b/crates/ide_assists/src/handlers/extract_function.rs @@ -16,12 +16,13 @@ edit::{AstNodeEdit, IndentLevel}, AstNode, }, + ted, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T, }; use crate::{ - assist_context::{AssistContext, Assists}, + assist_context::{AssistContext, Assists, TreeMutator}, AssistId, }; @@ -1366,7 +1367,10 @@ fn rewrite_body_segment( /// change all usages to account for added `&`/`&mut` for some params fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { - let mut rewriter = SyntaxRewriter::default(); + let mut usages_for_param: Vec<(&Param, Vec)> = Vec::new(); + + let tm = TreeMutator::new(syntax); + for param in params { if !param.kind().is_ref() { continue; @@ -1376,30 +1380,39 @@ fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) let usages = usages .iter() .filter(|reference| syntax.text_range().contains_range(reference.range)) - .filter_map(|reference| path_element_of_reference(syntax, reference)); - for path in usages { - match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + .filter_map(|reference| path_element_of_reference(syntax, reference)) + .map(|expr| tm.make_mut(&expr)); + + usages_for_param.push((param, usages.collect())); + } + + let res = tm.make_syntax_mut(syntax); + + for (param, usages) in usages_for_param { + for usage in usages { + match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { Some(ast::Expr::MethodCallExpr(_)) | Some(ast::Expr::FieldExpr(_)) => { // do nothing } Some(ast::Expr::RefExpr(node)) if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + ted::replace(node.syntax(), node.expr().unwrap().syntax()); } Some(ast::Expr::RefExpr(node)) if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + ted::replace(node.syntax(), node.expr().unwrap().syntax()); } Some(_) | None => { - rewriter.replace_ast(&path, &make::expr_prefix(T![*], path.clone())); + let p = &make::expr_prefix(T![*], usage.clone()).clone_for_update(); + ted::replace(usage.syntax(), p.syntax()) } - }; + } } } - rewriter.rewrite(syntax) + res } fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) -> SyntaxNode {