From b537cb186ed7b200c8ca86a70be81c56ecd154a3 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 29 Jul 2021 17:17:45 +0200 Subject: [PATCH] Use more strictly typed syntax nodes for analysis in extract_function assist --- Cargo.lock | 1 + crates/ide_assists/Cargo.toml | 1 + .../src/handlers/extract_function.rs | 212 ++++++++++-------- crates/syntax/src/ast/node_ext.rs | 75 +++++++ 4 files changed, 200 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3b27871b477..169961eabf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -600,6 +600,7 @@ dependencies = [ "expect-test", "hir", "ide_db", + "indexmap", "itertools", "profile", "rustc-hash", diff --git a/crates/ide_assists/Cargo.toml b/crates/ide_assists/Cargo.toml index 05b7811d908..c34798d0ea6 100644 --- a/crates/ide_assists/Cargo.toml +++ b/crates/ide_assists/Cargo.toml @@ -13,6 +13,7 @@ cov-mark = "2.0.0-pre.1" rustc-hash = "1.1.0" itertools = "0.10.0" either = "1.6.1" +indexmap = "1.6.2" stdx = { path = "../stdx", version = "0.0.0" } syntax = { path = "../syntax", version = "0.0.0" } diff --git a/crates/ide_assists/src/handlers/extract_function.rs b/crates/ide_assists/src/handlers/extract_function.rs index abf8329f017..350e204a164 100644 --- a/crates/ide_assists/src/handlers/extract_function.rs +++ b/crates/ide_assists/src/handlers/extract_function.rs @@ -1,13 +1,14 @@ -use std::iter; +use std::{hash::BuildHasherDefault, iter}; use ast::make; use either::Either; -use hir::{HirDisplay, Local}; +use hir::{HirDisplay, Local, Semantics}; use ide_db::{ defs::{Definition, NameRefClass}, search::{FileReference, ReferenceAccess, SearchScope}, + RootDatabase, }; -use itertools::Itertools; +use rustc_hash::FxHasher; use stdx::format_to; use syntax::{ ast::{ @@ -25,6 +26,8 @@ AssistId, }; +type FxIndexSet = indexmap::IndexSet>; + // Assist: extract_function // // Extracts selected statements into new function. @@ -51,7 +54,8 @@ // } // ``` pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { - if ctx.frange.range.is_empty() { + let range = ctx.frange.range; + if range.is_empty() { return None; } @@ -65,11 +69,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option syntax::NodeOrToken::Node(n) => n, syntax::NodeOrToken::Token(t) => t.parent()?, }; + let body = extraction_target(&node, range)?; - let body = extraction_target(&node, ctx.frange.range)?; - - let vars_used_in_body = vars_used_in_body(ctx, &body); - let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body); + let (locals_used, has_await, self_param) = analyze_body(&ctx.sema, &body); let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; let insert_after = scope_for_fn_insertion(&body, anchor)?; @@ -95,7 +97,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option "Extract into function", target_range, move |builder| { - let params = extracted_function_params(ctx, &body, &vars_used_in_body); + let params = extracted_function_params(ctx, &body, locals_used.iter().copied()); let fun = Function { name: "fun_name".to_string(), @@ -109,15 +111,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let new_indent = IndentLevel::from_node(&insert_after); let old_indent = fun.body.indent_level(); - let body_contains_await = body_contains_await(&fun.body); - builder.replace( - target_range, - format_replacement(ctx, &fun, old_indent, body_contains_await), - ); + builder.replace(target_range, format_replacement(ctx, &fun, old_indent, has_await)); - let fn_def = - format_function(ctx, module, &fun, old_indent, new_indent, body_contains_await); + let fn_def = format_function(ctx, module, &fun, old_indent, new_indent, has_await); let insert_offset = insert_after.text_range().end(); match ctx.config.snippet_cap { Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def), @@ -500,15 +497,59 @@ fn tail_expr(&self) -> Option { } } - fn descendants(&self) -> impl Iterator + '_ { + fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) { match self { - FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), - FunctionBody::Span { parent, text_range } => Either::Left( + FunctionBody::Expr(expr) => expr.walk(cb), + FunctionBody::Span { parent, text_range } => { parent - .syntax() - .descendants() - .filter(move |it| text_range.contains_range(it.text_range())), - ), + .statements() + .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) + .filter_map(|stmt| match stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt.expr(), + ast::Stmt::Item(_) => None, + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + }) + .for_each(|expr| expr.walk(cb)); + if let Some(expr) = parent + .tail_expr() + .filter(|it| text_range.contains_range(it.syntax().text_range())) + { + expr.walk(cb); + } + } + } + } + + fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) { + match self { + FunctionBody::Expr(expr) => expr.walk_patterns(cb), + FunctionBody::Span { parent, text_range } => { + parent + .statements() + .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) + .for_each(|stmt| match stmt { + ast::Stmt::ExprStmt(expr_stmt) => { + if let Some(expr) = expr_stmt.expr() { + expr.walk_patterns(cb) + } + } + ast::Stmt::Item(_) => (), + ast::Stmt::LetStmt(stmt) => { + if let Some(pat) = stmt.pat() { + pat.walk(cb); + } + if let Some(expr) = stmt.initializer() { + expr.walk_patterns(cb); + } + } + }); + if let Some(expr) = parent + .tail_expr() + .filter(|it| text_range.contains_range(it.syntax().text_range())) + { + expr.walk_patterns(cb); + } + } } } @@ -622,58 +663,48 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option Vec { - // FIXME: currently usages inside macros are not found - body.descendants() - .filter_map(ast::NameRef::cast) - .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) - .map(|name_kind| match name_kind { - NameRefClass::Definition(def) => def, - NameRefClass::FieldShorthand { local_ref, field_ref: _ } => { - Definition::Local(local_ref) - } - }) - .filter_map(|definition| match definition { - Definition::Local(local) => Some(local), - _ => None, - }) - .unique() - .collect() -} - -fn body_contains_await(body: &FunctionBody) -> bool { - body.descendants().any(|d| matches!(d.kind(), SyntaxKind::AWAIT_EXPR)) -} - -/// find `self` param, that was not defined inside `body` -/// -/// It should skip `self` params from impls inside `body` -fn self_param_from_usages( - ctx: &AssistContext, +/// Analyzes a function body, returning the used local variables that are referenced in it as well as +/// whether it contains an await expression. +fn analyze_body( + sema: &Semantics, body: &FunctionBody, - vars_used_in_body: &[Local], -) -> Option<(Local, ast::SelfParam)> { - let mut iter = vars_used_in_body - .iter() - .filter(|var| var.is_self(ctx.db())) - .map(|var| (var, var.source(ctx.db()))) - .filter(|(_, src)| is_defined_before(ctx, body, src)) - .filter_map(|(&node, src)| match src.value { - Either::Right(it) => Some((node, it)), - Either::Left(_) => { - stdx::never!(false, "Local::is_self returned true, but source is IdentPat"); - None +) -> (FxIndexSet, bool, Option<(Local, ast::SelfParam)>) { + // FIXME: currently usages inside macros are not found + let mut has_await = false; + let mut self_param = None; + let mut res = FxIndexSet::default(); + body.walk_expr(&mut |expr| { + has_await |= matches!(expr, ast::Expr::AwaitExpr(_)); + let name_ref = match expr { + ast::Expr::PathExpr(path_expr) => { + path_expr.path().and_then(|it| it.as_single_name_ref()) } - }); - - let self_param = iter.next(); - stdx::always!( - iter.next().is_none(), - "body references two different self params, both defined outside" - ); - - self_param + _ => return, + }; + if let Some(name_ref) = name_ref { + if let Some( + NameRefClass::Definition(Definition::Local(local_ref)) + | NameRefClass::FieldShorthand { local_ref, field_ref: _ }, + ) = NameRefClass::classify(sema, &name_ref) + { + res.insert(local_ref); + if local_ref.is_self(sema.db) { + match local_ref.source(sema.db).value { + Either::Right(it) => { + stdx::always!( + self_param.replace((local_ref, it)).is_none(), + "body references two different self params" + ); + } + Either::Left(_) => { + stdx::never!("Local::is_self returned true, but source is IdentPat"); + } + } + } + } + } + }); + (res, has_await, self_param) } /// find variables that should be extracted as params @@ -682,16 +713,15 @@ fn self_param_from_usages( fn extracted_function_params( ctx: &AssistContext, body: &FunctionBody, - vars_used_in_body: &[Local], + locals: impl Iterator, ) -> Vec { - vars_used_in_body - .iter() - .filter(|var| !var.is_self(ctx.db())) - .map(|node| (node, node.source(ctx.db()))) - .filter(|(_, src)| is_defined_before(ctx, body, src)) - .filter_map(|(&node, src)| { + locals + .filter(|local| !local.is_self(ctx.db())) + .map(|local| (local, local.source(ctx.db()))) + .filter(|(_, src)| is_defined_outside_of_body(ctx, body, src)) + .filter_map(|(local, src)| { if src.value.is_left() { - Some(node) + Some(local) } else { stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); None @@ -838,14 +868,18 @@ fn path_element_of_reference( } /// list local variables defined inside `body` -fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { +fn locals_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> FxIndexSet { // FIXME: this doesn't work well with macros // see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550 - body.descendants() - .filter_map(ast::IdentPat::cast) - .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) - .unique() - .collect() + let mut res = FxIndexSet::default(); + body.walk_pat(&mut |pat| { + if let ast::Pat::IdentPat(pat) = pat { + if let Some(local) = ctx.sema.to_def(&pat) { + res.insert(local); + } + } + }); + res } /// list local variables defined inside `body` that should be returned from extracted function @@ -854,7 +888,7 @@ fn vars_defined_in_body_and_outlive( body: &FunctionBody, parent: &SyntaxNode, ) -> Vec { - let vars_defined_in_body = vars_defined_in_body(body, ctx); + let vars_defined_in_body = locals_defined_in_body(body, ctx); vars_defined_in_body .into_iter() .filter_map(|var| var_outlives_body(ctx, body, var, parent)) @@ -862,7 +896,7 @@ fn vars_defined_in_body_and_outlive( } /// checks if the relevant local was defined before(outside of) body -fn is_defined_before( +fn is_defined_outside_of_body( ctx: &AssistContext, body: &FunctionBody, src: &hir::InFile>, diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 73d66281802..e9465536c8d 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -103,6 +103,81 @@ pub fn walk(&self, cb: &mut dyn FnMut(ast::Expr)) { } } } + + /// Preorder walk all the expression's child patterns. + pub fn walk_patterns(&self, cb: &mut dyn FnMut(ast::Pat)) { + let mut preorder = self.syntax().preorder(); + while let Some(event) = preorder.next() { + let node = match event { + WalkEvent::Enter(node) => node, + WalkEvent::Leave(_) => continue, + }; + match ast::Stmt::cast(node.clone()) { + Some(ast::Stmt::LetStmt(l)) => { + if let Some(pat) = l.pat() { + pat.walk(cb); + } + if let Some(expr) = l.initializer() { + expr.walk_patterns(cb); + } + preorder.skip_subtree(); + } + // Don't skip subtree since we want to process the expression child next + Some(ast::Stmt::ExprStmt(_)) => (), + // skip inner items which might have their own patterns + Some(ast::Stmt::Item(_)) => preorder.skip_subtree(), + None => { + // skip const args, those are a different context + if ast::GenericArg::can_cast(node.kind()) { + preorder.skip_subtree(); + } else if let Some(expr) = ast::Expr::cast(node.clone()) { + let is_different_context = match &expr { + ast::Expr::EffectExpr(effect) => { + matches!( + effect.effect(), + ast::Effect::Async(_) + | ast::Effect::Try(_) + | ast::Effect::Const(_) + ) + } + ast::Expr::ClosureExpr(_) => true, + _ => false, + }; + if is_different_context { + preorder.skip_subtree(); + } + } else if let Some(pat) = ast::Pat::cast(node) { + preorder.skip_subtree(); + pat.walk(cb); + } + } + } + } + } +} + +impl ast::Pat { + /// Preorder walk all the pattern's sub patterns. + pub fn walk(&self, cb: &mut dyn FnMut(ast::Pat)) { + let mut preorder = self.syntax().preorder(); + while let Some(event) = preorder.next() { + let node = match event { + WalkEvent::Enter(node) => node, + WalkEvent::Leave(_) => continue, + }; + match ast::Pat::cast(node.clone()) { + Some(ast::Pat::ConstBlockPat(_)) => preorder.skip_subtree(), + Some(pat) => { + cb(pat); + } + // skip const args + None if ast::GenericArg::can_cast(node.kind()) => { + preorder.skip_subtree(); + } + None => (), + } + } + } } #[derive(Debug, PartialEq, Eq, Clone)]