From e55a44a831477e2fc8e11340c3d91db883b97c8e Mon Sep 17 00:00:00 2001
From: Lukas Wirth <lukastw97@gmail.com>
Date: Sat, 14 Nov 2020 17:49:36 +0100
Subject: [PATCH] Use shorthand record syntax when renaming struct initializer
 field

---
 crates/ide/src/references/rename.rs | 49 ++++++++++++++++++++++++++---
 crates/ide_db/src/search.rs         | 16 ++++++----
 crates/syntax/src/ast/expr_ext.rs   | 12 +++++++
 crates/syntax/src/ast/node_ext.rs   | 10 +-----
 4 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/crates/ide/src/references/rename.rs b/crates/ide/src/references/rename.rs
index 26ac2371a52..b3ade20ef0a 100644
--- a/crates/ide/src/references/rename.rs
+++ b/crates/ide/src/references/rename.rs
@@ -106,7 +106,11 @@ fn find_module_at_offset(
     Some(module)
 }
 
-fn source_edit_from_reference(reference: Reference, new_name: &str) -> SourceFileEdit {
+fn source_edit_from_reference(
+    sema: &Semantics<RootDatabase>,
+    reference: Reference,
+    new_name: &str,
+) -> SourceFileEdit {
     let mut replacement_text = String::new();
     let file_id = reference.file_range.file_id;
     let range = match reference.kind {
@@ -122,6 +126,22 @@ fn source_edit_from_reference(reference: Reference, new_name: &str) -> SourceFil
             replacement_text.push_str(new_name);
             TextRange::new(reference.file_range.range.end(), reference.file_range.range.end())
         }
+        ReferenceKind::RecordExprField => {
+            replacement_text.push_str(new_name);
+            let mut range = reference.file_range.range;
+            if let Some(field_expr) = syntax::algo::find_node_at_range::<ast::RecordExprField>(
+                sema.parse(file_id).syntax(),
+                reference.file_range.range,
+            ) {
+                // use shorthand initializer if we were to write foo: foo
+                if let Some(name) = field_expr.expr().and_then(|e| e.name_ref()) {
+                    if &name.to_string() == new_name {
+                        range = field_expr.syntax().text_range();
+                    }
+                }
+            }
+            range
+        }
         _ => {
             replacement_text.push_str(new_name);
             reference.file_range.range
@@ -170,7 +190,7 @@ fn rename_mod(
     let ref_edits = refs
         .references
         .into_iter()
-        .map(|reference| source_edit_from_reference(reference, new_name));
+        .map(|reference| source_edit_from_reference(sema, reference, new_name));
     source_file_edits.extend(ref_edits);
 
     Ok(RangeInfo::new(range, SourceChange::from_edits(source_file_edits, file_system_edits)))
@@ -211,7 +231,7 @@ fn rename_to_self(
 
     let mut edits = usages
         .into_iter()
-        .map(|reference| source_edit_from_reference(reference, "self"))
+        .map(|reference| source_edit_from_reference(sema, reference, "self"))
         .collect::<Vec<_>>();
 
     edits.push(SourceFileEdit {
@@ -300,7 +320,7 @@ fn rename_reference(
 
     let edit = refs
         .into_iter()
-        .map(|reference| source_edit_from_reference(reference, new_name))
+        .map(|reference| source_edit_from_reference(sema, reference, new_name))
         .collect::<Vec<_>>();
 
     if edit.is_empty() {
@@ -1094,6 +1114,27 @@ impl Foo {
         foo.i
     }
 }
+"#,
+        );
+    }
+
+    #[test]
+    fn test_initializer_use_field_init_shorthand() {
+        check(
+            "bar",
+            r#"
+struct Foo { i<|>: i32 }
+
+fn foo(bar: i32) -> Foo {
+    Foo { i: bar }
+}
+"#,
+            r#"
+struct Foo { bar: i32 }
+
+fn foo(bar: i32) -> Foo {
+    Foo { bar }
+}
 "#,
         );
     }
diff --git a/crates/ide_db/src/search.rs b/crates/ide_db/src/search.rs
index a243352406e..4248606c8a3 100644
--- a/crates/ide_db/src/search.rs
+++ b/crates/ide_db/src/search.rs
@@ -30,6 +30,7 @@ pub enum ReferenceKind {
     FieldShorthandForField,
     FieldShorthandForLocal,
     StructLiteral,
+    RecordExprField,
     Other,
 }
 
@@ -278,12 +279,15 @@ impl<'a> FindUsages<'a> {
     ) -> bool {
         match NameRefClass::classify(self.sema, &name_ref) {
             Some(NameRefClass::Definition(def)) if &def == self.def => {
-                let kind = if is_record_lit_name_ref(&name_ref) || is_call_expr_name_ref(&name_ref)
-                {
-                    ReferenceKind::StructLiteral
-                } else {
-                    ReferenceKind::Other
-                };
+                let kind =
+                    if name_ref.syntax().parent().and_then(ast::RecordExprField::cast).is_some() {
+                        ReferenceKind::RecordExprField
+                    } else if is_record_lit_name_ref(&name_ref) || is_call_expr_name_ref(&name_ref)
+                    {
+                        ReferenceKind::StructLiteral
+                    } else {
+                        ReferenceKind::Other
+                    };
 
                 let reference = Reference {
                     file_range: self.sema.original_range(name_ref.syntax()),
diff --git a/crates/syntax/src/ast/expr_ext.rs b/crates/syntax/src/ast/expr_ext.rs
index 9253c97d084..e4a9b945c98 100644
--- a/crates/syntax/src/ast/expr_ext.rs
+++ b/crates/syntax/src/ast/expr_ext.rs
@@ -22,6 +22,18 @@ impl ast::Expr {
             _ => false,
         }
     }
+
+    pub fn name_ref(&self) -> Option<ast::NameRef> {
+        if let ast::Expr::PathExpr(expr) = self {
+            let path = expr.path()?;
+            let segment = path.segment()?;
+            let name_ref = segment.name_ref()?;
+            if path.qualifier().is_none() {
+                return Some(name_ref);
+            }
+        }
+        None
+    }
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs
index ce35ac01afd..b70b840b81a 100644
--- a/crates/syntax/src/ast/node_ext.rs
+++ b/crates/syntax/src/ast/node_ext.rs
@@ -203,15 +203,7 @@ impl ast::RecordExprField {
         if let Some(name_ref) = self.name_ref() {
             return Some(name_ref);
         }
-        if let Some(ast::Expr::PathExpr(expr)) = self.expr() {
-            let path = expr.path()?;
-            let segment = path.segment()?;
-            let name_ref = segment.name_ref()?;
-            if path.qualifier().is_none() {
-                return Some(name_ref);
-            }
-        }
-        None
+        self.expr()?.name_ref()
     }
 }