From fbe04244f8256b6dad38d8bcaac96c8a3a754a04 Mon Sep 17 00:00:00 2001
From: ding-young <lsyhime@snu.ac.kr>
Date: Wed, 7 Aug 2024 18:00:02 +0900
Subject: [PATCH] update rewrite_chain to return RewriteResult

---
 src/chains.rs | 161 +++++++++++++++++++++++++++++++++-----------------
 src/expr.rs   |   2 +-
 2 files changed, 107 insertions(+), 56 deletions(-)

diff --git a/src/chains.rs b/src/chains.rs
index 65326527000..a3f70055425 100644
--- a/src/chains.rs
+++ b/src/chains.rs
@@ -79,6 +79,9 @@ use thin_vec::ThinVec;
 /// Provides the original input contents from the span
 /// of a chain element with trailing spaces trimmed.
 fn format_overflow_style(span: Span, context: &RewriteContext<'_>) -> Option<String> {
+    // TODO(ding-young): Currently returning None when the given span is out of the range
+    // covered by the snippet provider. If this is a common cause for internal
+    // rewrite failure, add a new enum variant and return RewriteError instead of None
     context.snippet_provider.span_to_snippet(span).map(|s| {
         s.lines()
             .map(|l| l.trim_end())
@@ -92,12 +95,16 @@ fn format_chain_item(
     context: &RewriteContext<'_>,
     rewrite_shape: Shape,
     allow_overflow: bool,
-) -> Option<String> {
+) -> RewriteResult {
     if allow_overflow {
-        item.rewrite(context, rewrite_shape)
-            .or_else(|| format_overflow_style(item.span, context))
+        // TODO(ding-young): Consider calling format_overflow_style()
+        // only when item.rewrite_result() returns RewriteError::ExceedsMaxWidth.
+        // It may be inappropriate to call format_overflow_style on other RewriteError
+        // since the current approach retries formatting if allow_overflow is true
+        item.rewrite_result(context, rewrite_shape)
+            .or_else(|_| format_overflow_style(item.span, context).unknown_error())
     } else {
-        item.rewrite(context, rewrite_shape)
+        item.rewrite_result(context, rewrite_shape)
     }
 }
 
@@ -134,17 +141,17 @@ pub(crate) fn rewrite_chain(
     expr: &ast::Expr,
     context: &RewriteContext<'_>,
     shape: Shape,
-) -> Option<String> {
+) -> RewriteResult {
     let chain = Chain::from_ast(expr, context);
     debug!("rewrite_chain {:?} {:?}", chain, shape);
 
     // If this is just an expression with some `?`s, then format it trivially and
     // return early.
     if chain.children.is_empty() {
-        return chain.parent.rewrite(context, shape);
+        return chain.parent.rewrite_result(context, shape);
     }
 
-    chain.rewrite(context, shape)
+    chain.rewrite_result(context, shape)
 }
 
 #[derive(Debug)]
@@ -524,6 +531,10 @@ impl Chain {
 
 impl Rewrite for Chain {
     fn rewrite(&self, context: &RewriteContext<'_>, shape: Shape) -> Option<String> {
+        self.rewrite_result(context, shape).ok()
+    }
+
+    fn rewrite_result(&self, context: &RewriteContext<'_>, shape: Shape) -> RewriteResult {
         debug!("rewrite chain {:?} {:?}", self, shape);
 
         let mut formatter = match context.config.indent_style() {
@@ -537,17 +548,25 @@ impl Rewrite for Chain {
 
         formatter.format_root(&self.parent, context, shape)?;
         if let Some(result) = formatter.pure_root() {
-            return wrap_str(result, context.config.max_width(), shape);
+            return wrap_str(result, context.config.max_width(), shape)
+                .max_width_error(shape.width, self.parent.span);
         }
 
+        let first = self.children.first().unwrap_or(&self.parent);
+        let last = self.children.last().unwrap_or(&self.parent);
+        let children_span = mk_sp(first.span.lo(), last.span.hi());
+        let full_span = self.parent.span.with_hi(children_span.hi());
+
         // Decide how to layout the rest of the chain.
-        let child_shape = formatter.child_shape(context, shape)?;
+        let child_shape = formatter
+            .child_shape(context, shape)
+            .max_width_error(shape.width, children_span)?;
 
         formatter.format_children(context, child_shape)?;
         formatter.format_last_child(context, shape, child_shape)?;
 
         let result = formatter.join_rewrites(context, child_shape)?;
-        wrap_str(result, context.config.max_width(), shape)
+        wrap_str(result, context.config.max_width(), shape).max_width_error(shape.width, full_span)
     }
 }
 
@@ -569,16 +588,20 @@ trait ChainFormatter {
         parent: &ChainItem,
         context: &RewriteContext<'_>,
         shape: Shape,
-    ) -> Option<()>;
+    ) -> Result<(), RewriteError>;
     fn child_shape(&self, context: &RewriteContext<'_>, shape: Shape) -> Option<Shape>;
-    fn format_children(&mut self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<()>;
+    fn format_children(
+        &mut self,
+        context: &RewriteContext<'_>,
+        child_shape: Shape,
+    ) -> Result<(), RewriteError>;
     fn format_last_child(
         &mut self,
         context: &RewriteContext<'_>,
         shape: Shape,
         child_shape: Shape,
-    ) -> Option<()>;
-    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<String>;
+    ) -> Result<(), RewriteError>;
+    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> RewriteResult;
     // Returns `Some` if the chain is only a root, None otherwise.
     fn pure_root(&mut self) -> Option<String>;
 }
@@ -621,12 +644,16 @@ impl<'a> ChainFormatterShared<'a> {
         }
     }
 
-    fn format_children(&mut self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<()> {
+    fn format_children(
+        &mut self,
+        context: &RewriteContext<'_>,
+        child_shape: Shape,
+    ) -> Result<(), RewriteError> {
         for item in &self.children[..self.children.len() - 1] {
             let rewrite = format_chain_item(item, context, child_shape, self.allow_overflow)?;
             self.rewrites.push(rewrite);
         }
-        Some(())
+        Ok(())
     }
 
     // Rewrite the last child. The last child of a chain requires special treatment. We need to
@@ -667,8 +694,8 @@ impl<'a> ChainFormatterShared<'a> {
         context: &RewriteContext<'_>,
         shape: Shape,
         child_shape: Shape,
-    ) -> Option<()> {
-        let last = self.children.last()?;
+    ) -> Result<(), RewriteError> {
+        let last = self.children.last().unknown_error()?;
         let extendable = may_extend && last_line_extendable(&self.rewrites[0]);
         let prev_last_line_width = last_line_width(&self.rewrites[0]);
 
@@ -692,11 +719,17 @@ impl<'a> ChainFormatterShared<'a> {
             && self.rewrites.iter().all(|s| !s.contains('\n'))
             && one_line_budget > 0;
         let last_shape = if all_in_one_line {
-            shape.sub_width(last.tries)?
+            shape
+                .sub_width(last.tries)
+                .max_width_error(shape.width, last.span)?
         } else if extendable {
-            child_shape.sub_width(last.tries)?
+            child_shape
+                .sub_width(last.tries)
+                .max_width_error(child_shape.width, last.span)?
         } else {
-            child_shape.sub_width(shape.rhs_overhead(context.config) + last.tries)?
+            child_shape
+                .sub_width(shape.rhs_overhead(context.config) + last.tries)
+                .max_width_error(child_shape.width, last.span)?
         };
 
         let mut last_subexpr_str = None;
@@ -712,7 +745,7 @@ impl<'a> ChainFormatterShared<'a> {
             };
 
             if let Some(one_line_shape) = one_line_shape {
-                if let Some(rw) = last.rewrite(context, one_line_shape) {
+                if let Ok(rw) = last.rewrite_result(context, one_line_shape) {
                     // We allow overflowing here only if both of the following conditions match:
                     // 1. The entire chain fits in a single line except the last child.
                     // 2. `last_child_str.lines().count() >= 5`.
@@ -727,17 +760,18 @@ impl<'a> ChainFormatterShared<'a> {
                         // last child on its own line, and compare two rewrites to choose which is
                         // better.
                         let last_shape = child_shape
-                            .sub_width(shape.rhs_overhead(context.config) + last.tries)?;
-                        match last.rewrite(context, last_shape) {
-                            Some(ref new_rw) if !could_fit_single_line => {
+                            .sub_width(shape.rhs_overhead(context.config) + last.tries)
+                            .max_width_error(child_shape.width, last.span)?;
+                        match last.rewrite_result(context, last_shape) {
+                            Ok(ref new_rw) if !could_fit_single_line => {
                                 last_subexpr_str = Some(new_rw.clone());
                             }
-                            Some(ref new_rw) if new_rw.lines().count() >= line_count => {
+                            Ok(ref new_rw) if new_rw.lines().count() >= line_count => {
                                 last_subexpr_str = Some(rw);
                                 self.fits_single_line = could_fit_single_line && all_in_one_line;
                             }
-                            new_rw @ Some(..) => {
-                                last_subexpr_str = new_rw;
+                            Ok(new_rw) => {
+                                last_subexpr_str = Some(new_rw);
                             }
                             _ => {
                                 last_subexpr_str = Some(rw);
@@ -752,22 +786,28 @@ impl<'a> ChainFormatterShared<'a> {
         let last_shape = if context.use_block_indent() {
             last_shape
         } else {
-            child_shape.sub_width(shape.rhs_overhead(context.config) + last.tries)?
+            child_shape
+                .sub_width(shape.rhs_overhead(context.config) + last.tries)
+                .max_width_error(child_shape.width, last.span)?
         };
 
-        last_subexpr_str = last_subexpr_str.or_else(|| last.rewrite(context, last_shape));
-        self.rewrites.push(last_subexpr_str?);
-        Some(())
+        let last_subexpr_str =
+            last_subexpr_str.unwrap_or(last.rewrite_result(context, last_shape)?);
+        self.rewrites.push(last_subexpr_str);
+        Ok(())
     }
 
-    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<String> {
+    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> RewriteResult {
         let connector = if self.fits_single_line {
             // Yay, we can put everything on one line.
             Cow::from("")
         } else {
             // Use new lines.
             if context.force_one_line_chain.get() {
-                return None;
+                return Err(RewriteError::ExceedsMaxWidth {
+                    configured_width: child_shape.width,
+                    span: self.children.last().unknown_error()?.span,
+                });
             }
             child_shape.to_string_with_newline(context.config)
         };
@@ -786,7 +826,7 @@ impl<'a> ChainFormatterShared<'a> {
             result.push_str(rewrite);
         }
 
-        Some(result)
+        Ok(result)
     }
 }
 
@@ -811,8 +851,8 @@ impl<'a> ChainFormatter for ChainFormatterBlock<'a> {
         parent: &ChainItem,
         context: &RewriteContext<'_>,
         shape: Shape,
-    ) -> Option<()> {
-        let mut root_rewrite: String = parent.rewrite(context, shape)?;
+    ) -> Result<(), RewriteError> {
+        let mut root_rewrite: String = parent.rewrite_result(context, shape)?;
 
         let mut root_ends_with_block = parent.kind.is_block_like(context, &root_rewrite);
         let tab_width = context.config.tab_spaces().saturating_sub(shape.offset);
@@ -822,10 +862,12 @@ impl<'a> ChainFormatter for ChainFormatterBlock<'a> {
             if let ChainItemKind::Comment(..) = item.kind {
                 break;
             }
-            let shape = shape.offset_left(root_rewrite.len())?;
-            match &item.rewrite(context, shape) {
-                Some(rewrite) => root_rewrite.push_str(rewrite),
-                None => break,
+            let shape = shape
+                .offset_left(root_rewrite.len())
+                .max_width_error(shape.width, item.span)?;
+            match &item.rewrite_result(context, shape) {
+                Ok(rewrite) => root_rewrite.push_str(rewrite),
+                Err(_) => break,
             }
 
             root_ends_with_block = last_line_extendable(&root_rewrite);
@@ -837,7 +879,7 @@ impl<'a> ChainFormatter for ChainFormatterBlock<'a> {
         }
         self.shared.rewrites.push(root_rewrite);
         self.root_ends_with_block = root_ends_with_block;
-        Some(())
+        Ok(())
     }
 
     fn child_shape(&self, context: &RewriteContext<'_>, shape: Shape) -> Option<Shape> {
@@ -845,7 +887,11 @@ impl<'a> ChainFormatter for ChainFormatterBlock<'a> {
         Some(get_block_child_shape(block_end, context, shape))
     }
 
-    fn format_children(&mut self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<()> {
+    fn format_children(
+        &mut self,
+        context: &RewriteContext<'_>,
+        child_shape: Shape,
+    ) -> Result<(), RewriteError> {
         self.shared.format_children(context, child_shape)
     }
 
@@ -854,12 +900,12 @@ impl<'a> ChainFormatter for ChainFormatterBlock<'a> {
         context: &RewriteContext<'_>,
         shape: Shape,
         child_shape: Shape,
-    ) -> Option<()> {
+    ) -> Result<(), RewriteError> {
         self.shared
             .format_last_child(true, context, shape, child_shape)
     }
 
-    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<String> {
+    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> RewriteResult {
         self.shared.join_rewrites(context, child_shape)
     }
 
@@ -890,9 +936,9 @@ impl<'a> ChainFormatter for ChainFormatterVisual<'a> {
         parent: &ChainItem,
         context: &RewriteContext<'_>,
         shape: Shape,
-    ) -> Option<()> {
+    ) -> Result<(), RewriteError> {
         let parent_shape = shape.visual_indent(0);
-        let mut root_rewrite = parent.rewrite(context, parent_shape)?;
+        let mut root_rewrite = parent.rewrite_result(context, parent_shape)?;
         let multiline = root_rewrite.contains('\n');
         self.offset = if multiline {
             last_line_width(&root_rewrite).saturating_sub(shape.used_width())
@@ -904,18 +950,19 @@ impl<'a> ChainFormatter for ChainFormatterVisual<'a> {
             let item = &self.shared.children[0];
             if let ChainItemKind::Comment(..) = item.kind {
                 self.shared.rewrites.push(root_rewrite);
-                return Some(());
+                return Ok(());
             }
             let child_shape = parent_shape
                 .visual_indent(self.offset)
-                .sub_width(self.offset)?;
-            let rewrite = item.rewrite(context, child_shape)?;
+                .sub_width(self.offset)
+                .max_width_error(parent_shape.width, item.span)?;
+            let rewrite = item.rewrite_result(context, child_shape)?;
             if filtered_str_fits(&rewrite, context.config.max_width(), shape) {
                 root_rewrite.push_str(&rewrite);
             } else {
                 // We couldn't fit in at the visual indent, try the last
                 // indent.
-                let rewrite = item.rewrite(context, parent_shape)?;
+                let rewrite = item.rewrite_result(context, parent_shape)?;
                 root_rewrite.push_str(&rewrite);
                 self.offset = 0;
             }
@@ -924,7 +971,7 @@ impl<'a> ChainFormatter for ChainFormatterVisual<'a> {
         }
 
         self.shared.rewrites.push(root_rewrite);
-        Some(())
+        Ok(())
     }
 
     fn child_shape(&self, context: &RewriteContext<'_>, shape: Shape) -> Option<Shape> {
@@ -937,7 +984,11 @@ impl<'a> ChainFormatter for ChainFormatterVisual<'a> {
         )
     }
 
-    fn format_children(&mut self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<()> {
+    fn format_children(
+        &mut self,
+        context: &RewriteContext<'_>,
+        child_shape: Shape,
+    ) -> Result<(), RewriteError> {
         self.shared.format_children(context, child_shape)
     }
 
@@ -946,12 +997,12 @@ impl<'a> ChainFormatter for ChainFormatterVisual<'a> {
         context: &RewriteContext<'_>,
         shape: Shape,
         child_shape: Shape,
-    ) -> Option<()> {
+    ) -> Result<(), RewriteError> {
         self.shared
             .format_last_child(false, context, shape, child_shape)
     }
 
-    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> Option<String> {
+    fn join_rewrites(&self, context: &RewriteContext<'_>, child_shape: Shape) -> RewriteResult {
         self.shared.join_rewrites(context, child_shape)
     }
 
diff --git a/src/expr.rs b/src/expr.rs
index 261c7b33e0a..2d851f18889 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -241,7 +241,7 @@ pub(crate) fn format_expr(
         ast::ExprKind::Try(..)
         | ast::ExprKind::Field(..)
         | ast::ExprKind::MethodCall(..)
-        | ast::ExprKind::Await(_, _) => rewrite_chain(expr, context, shape),
+        | ast::ExprKind::Await(_, _) => rewrite_chain(expr, context, shape).ok(),
         ast::ExprKind::MacCall(ref mac) => {
             rewrite_macro(mac, None, context, shape, MacroPosition::Expression).or_else(|| {
                 wrap_str(