From 3983881d4e00c2b12d1b5b0319b4c61d72926917 Mon Sep 17 00:00:00 2001
From: yukang <moorekang@gmail.com>
Date: Tue, 6 Jun 2023 23:51:09 +0800
Subject: [PATCH] take care module name for suggesting surround the struct
 literal in parentheses

---
 .../rustc_parse/src/parser/diagnostics.rs     | 10 +++-
 compiler/rustc_span/src/source_map.rs         | 12 +++++
 tests/ui/parser/issues/issue-111692.rs        | 32 +++++++++++++
 tests/ui/parser/issues/issue-111692.stderr    | 46 +++++++++++++++++++
 4 files changed, 99 insertions(+), 1 deletion(-)
 create mode 100644 tests/ui/parser/issues/issue-111692.rs
 create mode 100644 tests/ui/parser/issues/issue-111692.stderr

diff --git a/compiler/rustc_parse/src/parser/diagnostics.rs b/compiler/rustc_parse/src/parser/diagnostics.rs
index c1454039685..2abd485b1be 100644
--- a/compiler/rustc_parse/src/parser/diagnostics.rs
+++ b/compiler/rustc_parse/src/parser/diagnostics.rs
@@ -751,10 +751,18 @@ impl<'a> Parser<'a> {
                     tail.could_be_bare_literal = true;
                     if maybe_struct_name.is_ident() && can_be_struct_literal {
                         // Account for `if Example { a: one(), }.is_pos() {}`.
+                        // expand `before` so that we take care of module path such as:
+                        // `foo::Bar { ... } `
+                        // we expect to suggest `(foo::Bar { ... })` instead of `foo::(Bar { ... })`
+                        let sm = self.sess.source_map();
+                        let before = maybe_struct_name.span.shrink_to_lo();
+                        let extend_before = sm.span_extend_prev_while(before, |t| {
+                            t.is_alphanumeric() || t == ':' || t == '_'
+                        });
                         Err(self.sess.create_err(StructLiteralNeedingParens {
                             span: maybe_struct_name.span.to(expr.span),
                             sugg: StructLiteralNeedingParensSugg {
-                                before: maybe_struct_name.span.shrink_to_lo(),
+                                before: extend_before.unwrap().shrink_to_lo(),
                                 after: expr.span.shrink_to_hi(),
                             },
                         }))
diff --git a/compiler/rustc_span/src/source_map.rs b/compiler/rustc_span/src/source_map.rs
index 1824510a974..f354751112f 100644
--- a/compiler/rustc_span/src/source_map.rs
+++ b/compiler/rustc_span/src/source_map.rs
@@ -744,6 +744,18 @@ impl SourceMap {
         })
     }
 
+    /// Extends the given `Span` to previous character while the previous character matches the predicate
+    pub fn span_extend_prev_while(
+        &self,
+        span: Span,
+        f: impl Fn(char) -> bool,
+    ) -> Result<Span, SpanSnippetError> {
+        self.span_to_source(span, |s, start, _end| {
+            let n = s[..start].char_indices().rfind(|&(_, c)| !f(c)).map_or(start, |(i, _)| start - i - 1);
+            Ok(span.with_lo(span.lo() - BytePos(n as u32)))
+        })
+    }
+
     /// Extends the given `Span` to just before the next occurrence of `c`.
     pub fn span_extend_to_next_char(&self, sp: Span, c: char, accept_newlines: bool) -> Span {
         if let Ok(next_source) = self.span_to_next_source(sp) {
diff --git a/tests/ui/parser/issues/issue-111692.rs b/tests/ui/parser/issues/issue-111692.rs
new file mode 100644
index 00000000000..56096f706a8
--- /dev/null
+++ b/tests/ui/parser/issues/issue-111692.rs
@@ -0,0 +1,32 @@
+mod module {
+    #[derive(Eq, PartialEq)]
+    pub struct Type {
+        pub x: u8,
+        pub y: u8,
+    }
+
+    pub const C: u8 = 32u8;
+}
+
+fn test(x: module::Type) {
+    if x == module::Type { x: module::C, y: 1 } { //~ ERROR invalid struct literal
+    }
+}
+
+fn test2(x: module::Type) {
+    if x ==module::Type { x: module::C, y: 1 } { //~ ERROR invalid struct literal
+    }
+}
+
+
+fn test3(x: module::Type) {
+    if x == Type { x: module::C, y: 1 } { //~ ERROR invalid struct literal
+    }
+}
+
+fn test4(x: module::Type) {
+    if x == demo_module::Type { x: module::C, y: 1 } { //~ ERROR invalid struct literal
+    }
+}
+
+fn main() { }
diff --git a/tests/ui/parser/issues/issue-111692.stderr b/tests/ui/parser/issues/issue-111692.stderr
new file mode 100644
index 00000000000..7b09d47301d
--- /dev/null
+++ b/tests/ui/parser/issues/issue-111692.stderr
@@ -0,0 +1,46 @@
+error: invalid struct literal
+  --> $DIR/issue-111692.rs:12:21
+   |
+LL |     if x == module::Type { x: module::C, y: 1 } {
+   |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   |
+help: you might need to surround the struct literal in parentheses
+   |
+LL |     if x == (module::Type { x: module::C, y: 1 }) {
+   |             +                                   +
+
+error: invalid struct literal
+  --> $DIR/issue-111692.rs:17:20
+   |
+LL |     if x ==module::Type { x: module::C, y: 1 } {
+   |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   |
+help: you might need to surround the struct literal in parentheses
+   |
+LL |     if x ==(module::Type { x: module::C, y: 1 }) {
+   |            +                                   +
+
+error: invalid struct literal
+  --> $DIR/issue-111692.rs:23:13
+   |
+LL |     if x == Type { x: module::C, y: 1 } {
+   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   |
+help: you might need to surround the struct literal in parentheses
+   |
+LL |     if x == (Type { x: module::C, y: 1 }) {
+   |             +                           +
+
+error: invalid struct literal
+  --> $DIR/issue-111692.rs:28:26
+   |
+LL |     if x == demo_module::Type { x: module::C, y: 1 } {
+   |                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   |
+help: you might need to surround the struct literal in parentheses
+   |
+LL |     if x == (demo_module::Type { x: module::C, y: 1 }) {
+   |             +                                        +
+
+error: aborting due to 4 previous errors
+