diff --git a/crates/assists/src/handlers/inline_local_variable.rs b/crates/assists/src/handlers/inline_local_variable.rs index 8d28431cf6d..9b228443f86 100644 --- a/crates/assists/src/handlers/inline_local_variable.rs +++ b/crates/assists/src/handlers/inline_local_variable.rs @@ -1,6 +1,5 @@ -use std::collections::HashMap; - use ide_db::{defs::Definition, search::FileReference}; +use rustc_hash::FxHashMap; use syntax::{ ast::{self, AstNode, AstToken}, TextRange, @@ -111,7 +110,7 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext) -> O .collect::>() .map(|b| (file_id, b)) }) - .collect::>, _>>()?; + .collect::>, _>>()?; let init_str = initializer_expr.syntax().text().to_string(); let init_in_paren = format!("({})", &init_str); diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs index 17086f7d48c..a83b82f1b6c 100644 --- a/crates/ide/src/references.rs +++ b/crates/ide/src/references.rs @@ -11,7 +11,7 @@ pub(crate) mod rename; -use hir::Semantics; +use hir::{PathResolution, Semantics}; use ide_db::{ base_db::FileId, defs::{Definition, NameClass, NameRefClass}, @@ -22,7 +22,7 @@ use syntax::{ algo::find_node_at_offset, ast::{self, NameOwner}, - AstNode, SyntaxNode, TextRange, TokenAtOffset, T, + match_ast, AstNode, SyntaxNode, TextRange, T, }; use crate::{display::TryToNav, FilePosition, NavigationTarget}; @@ -47,29 +47,40 @@ pub(crate) fn find_all_refs( let _p = profile::span("find_all_refs"); let syntax = sema.parse(position.file_id).syntax().clone(); - let (opt_name, ctor_filter): (_, Option bool>) = if let Some(name) = - get_struct_def_name_for_struct_literal_search(&sema, &syntax, position) - { - ( - Some(name), - Some(|name_ref| is_record_lit_name_ref(name_ref) || is_call_expr_name_ref(name_ref)), - ) - } else if let Some(name) = get_enum_def_name_for_struct_literal_search(&sema, &syntax, position) - { - (Some(name), Some(is_enum_lit_name_ref)) - } else { - (sema.find_node_at_offset_with_descend::(&syntax, position.offset), None) - }; - - let def = find_def(&sema, &syntax, position, opt_name)?; + let (def, is_literal_search) = + if let Some(name) = get_name_of_item_declaration(&syntax, position) { + (NameClass::classify(sema, &name)?.referenced_or_defined(sema.db), true) + } else { + (find_def(&sema, &syntax, position)?, false) + }; let mut usages = def.usages(sema).set_scope(search_scope).all(); - if let Some(ctor_filter) = ctor_filter { + if is_literal_search { // filter for constructor-literals - usages.references.values_mut().for_each(|it| { - it.retain(|reference| reference.name.as_name_ref().map_or(false, ctor_filter)); - }); - usages.references.retain(|_, it| !it.is_empty()); + let refs = usages.references.values_mut(); + match def { + Definition::ModuleDef(hir::ModuleDef::Adt(hir::Adt::Enum(enum_))) => { + refs.for_each(|it| { + it.retain(|reference| { + reference + .name + .as_name_ref() + .map_or(false, |name_ref| is_enum_lit_name_ref(sema, enum_, name_ref)) + }) + }); + usages.references.retain(|_, it| !it.is_empty()); + } + Definition::ModuleDef(hir::ModuleDef::Adt(_)) + | Definition::ModuleDef(hir::ModuleDef::Variant(_)) => { + refs.for_each(|it| { + it.retain(|reference| { + reference.name.as_name_ref().map_or(false, is_lit_name_ref) + }) + }); + usages.references.retain(|_, it| !it.is_empty()); + } + _ => {} + } } let nav = def.try_to_nav(sema.db)?; let decl_range = nav.focus_or_full_range(); @@ -89,9 +100,9 @@ fn find_def( sema: &Semantics, syntax: &SyntaxNode, position: FilePosition, - opt_name: Option, ) -> Option { - if let Some(name) = opt_name { + if let Some(name) = sema.find_node_at_offset_with_descend::(&syntax, position.offset) + { let class = NameClass::classify(sema, &name)?; Some(class.referenced_or_defined(sema.db)) } else if let Some(lifetime) = @@ -134,95 +145,85 @@ fn decl_access(def: &Definition, syntax: &SyntaxNode, range: TextRange) -> Optio None } -fn get_struct_def_name_for_struct_literal_search( - sema: &Semantics, - syntax: &SyntaxNode, - position: FilePosition, -) -> Option { - if let TokenAtOffset::Between(ref left, ref right) = syntax.token_at_offset(position.offset) { - if right.kind() != T!['{'] && right.kind() != T!['('] { - return None; +fn get_name_of_item_declaration(syntax: &SyntaxNode, position: FilePosition) -> Option { + let token = syntax.token_at_offset(position.offset).right_biased()?; + let kind = token.kind(); + if kind == T![;] { + ast::Struct::cast(token.parent()) + .filter(|struct_| struct_.field_list().is_none()) + .and_then(|struct_| struct_.name()) + } else if kind == T!['{'] { + match_ast! { + match (token.parent()) { + ast::RecordFieldList(rfl) => match_ast! { + match (rfl.syntax().parent()?) { + ast::Variant(it) => it.name(), + ast::Struct(it) => it.name(), + ast::Union(it) => it.name(), + _ => None, + } + }, + ast::VariantList(vl) => ast::Enum::cast(vl.syntax().parent()?)?.name(), + _ => None, + } } - if let Some(name) = - sema.find_node_at_offset_with_descend::(&syntax, left.text_range().start()) - { - return name.syntax().ancestors().find_map(ast::Struct::cast).and_then(|l| l.name()); - } - if sema - .find_node_at_offset_with_descend::( - &syntax, - left.text_range().start(), - ) - .is_some() - { - return left.ancestors().find_map(ast::Struct::cast).and_then(|l| l.name()); + } else if kind == T!['('] { + let tfl = ast::TupleFieldList::cast(token.parent())?; + match_ast! { + match (tfl.syntax().parent()?) { + ast::Variant(it) => it.name(), + ast::Struct(it) => it.name(), + _ => None, + } } + } else { + None } - None } -fn get_enum_def_name_for_struct_literal_search( +fn is_enum_lit_name_ref( sema: &Semantics, - syntax: &SyntaxNode, - position: FilePosition, -) -> Option { - if let TokenAtOffset::Between(ref left, ref right) = syntax.token_at_offset(position.offset) { - if right.kind() != T!['{'] && right.kind() != T!['('] { - return None; - } - if let Some(name) = - sema.find_node_at_offset_with_descend::(&syntax, left.text_range().start()) - { - return name.syntax().ancestors().find_map(ast::Enum::cast).and_then(|l| l.name()); - } - if sema - .find_node_at_offset_with_descend::( - &syntax, - left.text_range().start(), - ) - .is_some() - { - return left.ancestors().find_map(ast::Enum::cast).and_then(|l| l.name()); - } - } - None -} - -fn is_call_expr_name_ref(name_ref: &ast::NameRef) -> bool { + enum_: hir::Enum, + name_ref: &ast::NameRef, +) -> bool { + let path_is_variant_of_enum = |path: ast::Path| { + matches!( + sema.resolve_path(&path), + Some(PathResolution::Def(hir::ModuleDef::Variant(variant))) + if variant.parent_enum(sema.db) == enum_ + ) + }; name_ref .syntax() .ancestors() - .find_map(ast::CallExpr::cast) - .and_then(|c| match c.expr()? { - ast::Expr::PathExpr(p) => { - Some(p.path()?.segment()?.name_ref().as_ref() == Some(name_ref)) + .find_map(|ancestor| { + match_ast! { + match ancestor { + ast::PathExpr(path_expr) => path_expr.path().map(path_is_variant_of_enum), + ast::RecordExpr(record_expr) => record_expr.path().map(path_is_variant_of_enum), + _ => None, + } } - _ => None, }) .unwrap_or(false) } -fn is_record_lit_name_ref(name_ref: &ast::NameRef) -> bool { - name_ref - .syntax() - .ancestors() - .find_map(ast::RecordExpr::cast) - .and_then(|l| l.path()) - .and_then(|p| p.segment()) - .map(|p| p.name_ref().as_ref() == Some(name_ref)) - .unwrap_or(false) +fn path_ends_with(path: Option, name_ref: &ast::NameRef) -> bool { + path.and_then(|path| path.segment()) + .and_then(|segment| segment.name_ref()) + .map_or(false, |segment| segment == *name_ref) } -fn is_enum_lit_name_ref(name_ref: &ast::NameRef) -> bool { - name_ref - .syntax() - .ancestors() - .find_map(ast::PathExpr::cast) - .and_then(|p| p.path()) - .and_then(|p| p.qualifier()) - .and_then(|p| p.segment()) - .map(|p| p.name_ref().as_ref() == Some(name_ref)) - .unwrap_or(false) +fn is_lit_name_ref(name_ref: &ast::NameRef) -> bool { + name_ref.syntax().ancestors().find_map(|ancestor| { + match_ast! { + match ancestor { + ast::PathExpr(path_expr) => Some(path_ends_with(path_expr.path(), name_ref)), + ast::RecordExpr(record_expr) => Some(path_ends_with(record_expr.path(), name_ref)), + _ => None, + } + } + }).unwrap_or(false) } #[cfg(test)] @@ -312,23 +313,92 @@ fn main() { ); } + #[test] + fn test_struct_literal_for_union() { + check( + r#" +union Foo $0{ + x: u32 +} + +fn main() { + let f: Foo; + f = Foo { x: 1 }; +} +"#, + expect![[r#" + Foo Union FileId(0) 0..24 6..9 + + FileId(0) 62..65 + "#]], + ); + } + #[test] fn test_enum_after_space() { check( r#" enum Foo $0{ A, - B, + B(), + C{}, } fn main() { let f: Foo; f = Foo::A; + f = Foo::B(); + f = Foo::C{}; } "#, expect![[r#" - Foo Enum FileId(0) 0..26 5..8 + Foo Enum FileId(0) 0..37 5..8 - FileId(0) 63..66 + FileId(0) 74..77 + FileId(0) 90..93 + FileId(0) 108..111 + "#]], + ); + } + + #[test] + fn test_variant_record_after_space() { + check( + r#" +enum Foo { + A $0{ n: i32 }, + B, +} +fn main() { + let f: Foo; + f = Foo::B; + f = Foo::A { n: 92 }; +} +"#, + expect![[r#" + A Variant FileId(0) 15..27 15..16 + + FileId(0) 95..96 + "#]], + ); + } + #[test] + fn test_variant_tuple_before_paren() { + check( + r#" +enum Foo { + A$0(i32), + B, +} +fn main() { + let f: Foo; + f = Foo::B; + f = Foo::A(92); +} +"#, + expect![[r#" + A Variant FileId(0) 15..21 15..16 + + FileId(0) 89..90 "#]], ); }