diff --git a/crates/ra_assists/src/handlers/auto_import.rs b/crates/ra_assists/src/handlers/auto_import.rs index a9778fab776..9a366414cab 100644 --- a/crates/ra_assists/src/handlers/auto_import.rs +++ b/crates/ra_assists/src/handlers/auto_import.rs @@ -1,15 +1,17 @@ use ra_ide_db::{imports_locator::ImportsLocator, RootDatabase}; -use ra_syntax::ast::{self, AstNode}; +use ra_syntax::{ + ast::{self, AstNode}, + SyntaxNode, +}; use crate::{ assist_ctx::{Assist, AssistCtx}, insert_use_statement, AssistId, }; -use ast::{FnDefOwner, ModuleItem, ModuleItemOwner}; use hir::{ db::{DefDatabase, HirDatabase}, - Adt, AssocContainerId, Crate, Function, HasSource, InFile, ModPath, Module, ModuleDef, - PathResolution, SourceAnalyzer, SourceBinder, Trait, + AssocContainerId, AssocItem, Crate, Function, ModPath, Module, ModuleDef, PathResolution, + SourceAnalyzer, Trait, Type, }; use rustc_hash::FxHashSet; use std::collections::BTreeSet; @@ -34,36 +36,28 @@ use std::collections::BTreeSet; // # pub mod std { pub mod collections { pub struct HashMap { } } } // ``` pub(crate) fn auto_import(ctx: AssistCtx) -> Option { - let path_under_caret: ast::Path = ctx.find_node_at_offset()?; - if path_under_caret.syntax().ancestors().find_map(ast::UseItem::cast).is_some() { - return None; - } - - let module = path_under_caret.syntax().ancestors().find_map(ast::Module::cast); - let position = match module.and_then(|it| it.item_list()) { - Some(item_list) => item_list.syntax().clone(), - None => { - let current_file = - path_under_caret.syntax().ancestors().find_map(ast::SourceFile::cast)?; - current_file.syntax().clone() - } + let auto_import_assets = if let Some(path_under_caret) = ctx.find_node_at_offset::() + { + AutoImportAssets::for_regular_path(path_under_caret, &ctx)? + } else { + AutoImportAssets::for_method_call(ctx.find_node_at_offset()?, &ctx)? }; - let source_analyzer = ctx.source_analyzer(&position, None); - let module_with_name_to_import = source_analyzer.module()?; - let import_candidate = ImportCandidate::new(&path_under_caret, &source_analyzer, ctx.db)?; - let proposed_imports = import_candidate.search_for_imports(ctx.db, module_with_name_to_import); + let proposed_imports = auto_import_assets + .search_for_imports(ctx.db, auto_import_assets.module_with_name_to_import); if proposed_imports.is_empty() { return None; } - let mut group = ctx.add_assist_group(format!("Import {}", import_candidate.get_search_query())); + let mut group = + // TODO kb create another method and add something about traits there + ctx.add_assist_group(format!("Import {}", auto_import_assets.get_search_query())); for import in proposed_imports { group.add_assist(AssistId("auto_import"), format!("Import `{}`", &import), |edit| { - edit.target(path_under_caret.syntax().text_range()); + edit.target(auto_import_assets.syntax_under_caret.text_range()); insert_use_statement( - &position, - path_under_caret.syntax(), + &auto_import_assets.syntax_under_caret, + &auto_import_assets.syntax_under_caret, &import, edit.text_edit_builder(), ); @@ -72,17 +66,204 @@ pub(crate) fn auto_import(ctx: AssistCtx) -> Option { group.finish() } +struct AutoImportAssets { + import_candidate: ImportCandidate, + module_with_name_to_import: Module, + syntax_under_caret: SyntaxNode, +} + +impl AutoImportAssets { + fn for_method_call(method_call: ast::MethodCallExpr, ctx: &AssistCtx) -> Option { + let syntax_under_caret = method_call.syntax().to_owned(); + let source_analyzer = ctx.source_analyzer(&syntax_under_caret, None); + let module_with_name_to_import = source_analyzer.module()?; + Some(Self { + import_candidate: ImportCandidate::for_method_call( + &method_call, + &source_analyzer, + ctx.db, + )?, + module_with_name_to_import, + syntax_under_caret, + }) + } + + fn for_regular_path(path_under_caret: ast::Path, ctx: &AssistCtx) -> Option { + let syntax_under_caret = path_under_caret.syntax().to_owned(); + if syntax_under_caret.ancestors().find_map(ast::UseItem::cast).is_some() { + return None; + } + + let source_analyzer = ctx.source_analyzer(&syntax_under_caret, None); + let module_with_name_to_import = source_analyzer.module()?; + Some(Self { + import_candidate: ImportCandidate::for_regular_path( + &path_under_caret, + &source_analyzer, + ctx.db, + )?, + module_with_name_to_import, + syntax_under_caret, + }) + } + + fn get_search_query(&self) -> String { + match &self.import_candidate { + ImportCandidate::UnqualifiedName(name_ref) + | ImportCandidate::QualifierStart(name_ref) => name_ref.syntax().to_string(), + ImportCandidate::TraitFunction(_, trait_function) => { + trait_function.syntax().to_string() + } + ImportCandidate::TraitMethod(_, trait_method) => trait_method.syntax().to_string(), + } + } + + fn search_for_imports( + &self, + db: &RootDatabase, + module_with_name_to_import: Module, + ) -> BTreeSet { + ImportsLocator::new(db) + .find_imports(&self.get_search_query()) + .into_iter() + .map(|module_def| match &self.import_candidate { + ImportCandidate::TraitFunction(function_callee, _) => { + let mut applicable_traits = Vec::new(); + if let ModuleDef::Function(located_function) = module_def { + let trait_candidates = Self::get_trait_candidates( + db, + located_function, + module_with_name_to_import.krate(), + ) + .into_iter() + .map(|trait_candidate| trait_candidate.into()) + .collect(); + + function_callee.iterate_path_candidates( + db, + module_with_name_to_import.krate(), + &trait_candidates, + None, + |_, assoc| { + if let AssocContainerId::TraitId(trait_id) = assoc.container(db) { + applicable_traits.push( + module_with_name_to_import + .find_use_path(db, ModuleDef::Trait(trait_id.into())), + ); + }; + None::<()> + }, + ); + } + applicable_traits + } + ImportCandidate::TraitMethod(function_callee, _) => { + let mut applicable_traits = Vec::new(); + if let ModuleDef::Function(located_function) = module_def { + let trait_candidates: FxHashSet<_> = Self::get_trait_candidates( + db, + located_function, + module_with_name_to_import.krate(), + ) + .into_iter() + .map(|trait_candidate| trait_candidate.into()) + .collect(); + + if !trait_candidates.is_empty() { + function_callee.iterate_method_candidates( + db, + module_with_name_to_import.krate(), + &trait_candidates, + None, + |_, funciton| { + if let AssocContainerId::TraitId(trait_id) = + funciton.container(db) + { + applicable_traits.push( + module_with_name_to_import.find_use_path( + db, + ModuleDef::Trait(trait_id.into()), + ), + ); + }; + None::<()> + }, + ); + } + } + applicable_traits + } + _ => vec![module_with_name_to_import.find_use_path(db, module_def)], + }) + .flatten() + .filter_map(std::convert::identity) + .filter(|use_path| !use_path.segments.is_empty()) + .take(20) + .collect::>() + } + + fn get_trait_candidates( + db: &RootDatabase, + called_function: Function, + root_crate: Crate, + ) -> FxHashSet { + root_crate + .dependencies(db) + .into_iter() + .map(|dependency| db.crate_def_map(dependency.krate.into())) + .chain(std::iter::once(db.crate_def_map(root_crate.into()))) + .map(|crate_def_map| { + crate_def_map + .modules + .iter() + .map(|(_, module_data)| { + let mut traits = Vec::new(); + for module_def_id in module_data.scope.declarations() { + if let ModuleDef::Trait(trait_candidate) = module_def_id.into() { + if trait_candidate + .items(db) + .into_iter() + .any(|item| item == AssocItem::Function(called_function)) + { + traits.push(trait_candidate) + } + } + } + traits + }) + .flatten() + .collect::>() + }) + .flatten() + .collect() + } +} + #[derive(Debug)] // TODO kb rustdocs enum ImportCandidate { UnqualifiedName(ast::NameRef), QualifierStart(ast::NameRef), - TraitFunction(Adt, ast::PathSegment), + TraitFunction(Type, ast::PathSegment), + TraitMethod(Type, ast::NameRef), } impl ImportCandidate { - // TODO kb refactor this mess - fn new( + fn for_method_call( + method_call: &ast::MethodCallExpr, + source_analyzer: &SourceAnalyzer, + db: &impl HirDatabase, + ) -> Option { + if source_analyzer.resolve_method_call(method_call).is_some() { + return None; + } + Some(Self::TraitMethod( + source_analyzer.type_of(db, &method_call.expr()?)?, + method_call.name_ref()?, + )) + } + + fn for_regular_path( path_under_caret: &ast::Path, source_analyzer: &SourceAnalyzer, db: &impl HirDatabase, @@ -105,7 +286,7 @@ impl ImportCandidate { source_analyzer.resolve_path(db, &qualifier)? }; if let PathResolution::Def(ModuleDef::Adt(function_callee)) = qualifier_resolution { - Some(ImportCandidate::TraitFunction(function_callee, segment)) + Some(ImportCandidate::TraitFunction(function_callee.ty(db), segment)) } else { None } @@ -122,107 +303,6 @@ impl ImportCandidate { } } } - - fn get_search_query(&self) -> String { - match self { - ImportCandidate::UnqualifiedName(name_ref) - | ImportCandidate::QualifierStart(name_ref) => name_ref.syntax().to_string(), - ImportCandidate::TraitFunction(_, trait_function) => { - trait_function.syntax().to_string() - } - } - } - - fn search_for_imports( - &self, - db: &RootDatabase, - module_with_name_to_import: Module, - ) -> BTreeSet { - ImportsLocator::new(db) - .find_imports(&self.get_search_query()) - .into_iter() - .map(|module_def| match self { - ImportCandidate::TraitFunction(function_callee, _) => { - let mut applicable_traits = Vec::new(); - if let ModuleDef::Function(located_function) = module_def { - let trait_candidates = Self::get_trait_candidates( - db, - located_function, - module_with_name_to_import.krate(), - ) - .into_iter() - .map(|trait_candidate| trait_candidate.into()) - .collect(); - - function_callee.ty(db).iterate_path_candidates( - db, - module_with_name_to_import.krate(), - &trait_candidates, - None, - |_, assoc| { - if let AssocContainerId::TraitId(trait_id) = assoc.container(db) { - applicable_traits.push( - module_with_name_to_import - .find_use_path(db, ModuleDef::Trait(trait_id.into())), - ); - }; - None::<()> - }, - ); - } - applicable_traits - } - _ => vec![module_with_name_to_import.find_use_path(db, module_def)], - }) - .flatten() - .filter_map(std::convert::identity) - .filter(|use_path| !use_path.segments.is_empty()) - .take(20) - .collect::>() - } - - fn get_trait_candidates( - db: &RootDatabase, - called_function: Function, - root_crate: Crate, - ) -> FxHashSet { - let mut source_binder = SourceBinder::new(db); - root_crate - .dependencies(db) - .into_iter() - .map(|dependency| db.crate_def_map(dependency.krate.into())) - .chain(std::iter::once(db.crate_def_map(root_crate.into()))) - .map(|crate_def_map| { - crate_def_map - .modules - .iter() - .filter_map(|(_, module_data)| module_data.declaration_source(db)) - .filter_map(|in_file_module| { - Some((in_file_module.file_id, in_file_module.value.item_list()?.items())) - }) - .map(|(file_id, item_list)| { - let mut if_file_trait_defs = Vec::new(); - for module_item in item_list { - if let ModuleItem::TraitDef(trait_def) = module_item { - if let Some(item_list) = trait_def.item_list() { - if item_list - .functions() - .any(|fn_def| fn_def == called_function.source(db).value) - { - if_file_trait_defs.push(InFile::new(file_id, trait_def)) - } - } - } - } - if_file_trait_defs - }) - .flatten() - .filter_map(|in_file_trait_def| source_binder.to_def(in_file_trait_def)) - .collect::>() - }) - .flatten() - .collect() - } } #[cfg(test)] @@ -525,32 +605,25 @@ mod tests { } #[test] - fn not_applicable_for_imported_trait() { + fn not_applicable_for_imported_trait_for_function() { check_assist_not_applicable( auto_import, r" mod test_mod { pub trait TestTrait { - fn test_method(&self); fn test_function(); } - pub trait TestTrait2 { - fn test_method(&self); fn test_function(); } pub enum TestEnum { One, Two, } - impl TestTrait2 for TestEnum { - fn test_method(&self) {} fn test_function() {} } - impl TestTrait for TestEnum { - fn test_method(&self) {} fn test_function() {} } } @@ -580,7 +653,7 @@ mod tests { fn main() { let test_struct = test_mod::TestStruct {}; - test_struct.test_method<|> + test_struct.test_meth<|>od() } ", r" @@ -598,9 +671,42 @@ mod tests { fn main() { let test_struct = test_mod::TestStruct {}; - test_struct.test_method<|> + test_struct.test_meth<|>od() } ", ); } + + #[test] + fn not_applicable_for_imported_trait_for_method() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub trait TestTrait2 { + fn test_method(&self); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_method(&self) {} + } + impl TestTrait for TestEnum { + fn test_method(&self) {} + } + } + + use test_mod::TestTrait2; + fn main() { + let one = test_mod::TestEnum::One; + one.test<|>_method(); + } + ", + ) + } } diff --git a/crates/ra_hir/src/code_model.rs b/crates/ra_hir/src/code_model.rs index 73158b8bd4c..140b3a87f1c 100644 --- a/crates/ra_hir/src/code_model.rs +++ b/crates/ra_hir/src/code_model.rs @@ -548,6 +548,10 @@ impl Function { let mut validator = ExprValidator::new(self.id, infer, sink); validator.validate_body(db); } + + pub fn container(self, db: &impl DefDatabase) -> AssocContainerId { + self.id.lookup(db).container + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -699,7 +703,7 @@ impl AssocItem { pub fn container(self, db: &impl DefDatabase) -> AssocContainerId { match self { - AssocItem::Function(f) => f.id.lookup(db).container, + AssocItem::Function(f) => f.container(db), AssocItem::Const(c) => c.id.lookup(db).container, AssocItem::TypeAlias(t) => t.id.lookup(db).container, }