From fbc3f11fc18a626ba9b9f6bc52de4fc5ef75154b Mon Sep 17 00:00:00 2001
From: Ulrik Sverdrup <bluss@users.noreply.github.com>
Date: Fri, 9 Dec 2016 16:28:54 +0100
Subject: [PATCH] mir: Reinstate while loop in deaggregator pass

A previous commit must have removed the `while let` loop here by
mistake; for each basic block, it should find and deaggregate multiple
statements in their index order, and the `curr` index tracks the
progress through the block.

This fixes both the case of deaggregating statements in separate
basic blocks (preserving `curr` could prevent that) as well
as multiple times in the same block (missing loop prevented that).
---
 src/librustc_mir/transform/deaggregator.rs    | 119 +++++++++---------
 src/test/mir-opt/deaggregator_test_enum_2.rs  |  57 +++++++++
 .../mir-opt/deaggregator_test_multiple.rs     |  48 +++++++
 3 files changed, 164 insertions(+), 60 deletions(-)
 create mode 100644 src/test/mir-opt/deaggregator_test_enum_2.rs
 create mode 100644 src/test/mir-opt/deaggregator_test_multiple.rs

diff --git a/src/librustc_mir/transform/deaggregator.rs b/src/librustc_mir/transform/deaggregator.rs
index fcdeae6d6c0..e13c8e02137 100644
--- a/src/librustc_mir/transform/deaggregator.rs
+++ b/src/librustc_mir/transform/deaggregator.rs
@@ -36,71 +36,70 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
         // In fact, we might not want to trigger in other cases.
         // Ex: when we could use SROA.  See issue #35259
 
-        let mut curr: usize = 0;
         for bb in mir.basic_blocks_mut() {
-            let idx = match get_aggregate_statement_index(curr, &bb.statements) {
-                Some(idx) => idx,
-                None => continue,
-            };
-            // do the replacement
-            debug!("removing statement {:?}", idx);
-            let src_info = bb.statements[idx].source_info;
-            let suffix_stmts = bb.statements.split_off(idx+1);
-            let orig_stmt = bb.statements.pop().unwrap();
-            let (lhs, rhs) = match orig_stmt.kind {
-                StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
-                _ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt),
-            };
-            let (agg_kind, operands) = match rhs {
-                &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
-                _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
-            };
-            let (adt_def, variant, substs) = match agg_kind {
-                &AggregateKind::Adt(adt_def, variant, substs, None) => (adt_def, variant, substs),
-                _ => span_bug!(src_info.span, "expected struct, not {:?}", rhs),
-            };
-            let n = bb.statements.len();
-            bb.statements.reserve(n + operands.len() + suffix_stmts.len());
-            for (i, op) in operands.iter().enumerate() {
-                let ref variant_def = adt_def.variants[variant];
-                let ty = variant_def.fields[i].ty(tcx, substs);
-                let rhs = Rvalue::Use(op.clone());
+            let mut curr: usize = 0;
+            while let Some(idx) = get_aggregate_statement_index(curr, &bb.statements) {
+                // do the replacement
+                debug!("removing statement {:?}", idx);
+                let src_info = bb.statements[idx].source_info;
+                let suffix_stmts = bb.statements.split_off(idx+1);
+                let orig_stmt = bb.statements.pop().unwrap();
+                let (lhs, rhs) = match orig_stmt.kind {
+                    StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
+                    _ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt),
+                };
+                let (agg_kind, operands) = match rhs {
+                    &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
+                    _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
+                };
+                let (adt_def, variant, substs) = match agg_kind {
+                    &AggregateKind::Adt(adt_def, variant, substs, None)
+                        => (adt_def, variant, substs),
+                    _ => span_bug!(src_info.span, "expected struct, not {:?}", rhs),
+                };
+                let n = bb.statements.len();
+                bb.statements.reserve(n + operands.len() + suffix_stmts.len());
+                for (i, op) in operands.iter().enumerate() {
+                    let ref variant_def = adt_def.variants[variant];
+                    let ty = variant_def.fields[i].ty(tcx, substs);
+                    let rhs = Rvalue::Use(op.clone());
 
-                let lhs_cast = if adt_def.variants.len() > 1 {
-                    Lvalue::Projection(Box::new(LvalueProjection {
-                        base: lhs.clone(),
-                        elem: ProjectionElem::Downcast(adt_def, variant),
-                    }))
-                } else {
-                    lhs.clone()
+                    let lhs_cast = if adt_def.variants.len() > 1 {
+                        Lvalue::Projection(Box::new(LvalueProjection {
+                            base: lhs.clone(),
+                            elem: ProjectionElem::Downcast(adt_def, variant),
+                        }))
+                    } else {
+                        lhs.clone()
+                    };
+
+                    let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
+                        base: lhs_cast,
+                        elem: ProjectionElem::Field(Field::new(i), ty),
+                    }));
+                    let new_statement = Statement {
+                        source_info: src_info,
+                        kind: StatementKind::Assign(lhs_proj, rhs),
+                    };
+                    debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
+                    bb.statements.push(new_statement);
+                }
+
+                // if the aggregate was an enum, we need to set the discriminant
+                if adt_def.variants.len() > 1 {
+                    let set_discriminant = Statement {
+                        kind: StatementKind::SetDiscriminant {
+                            lvalue: lhs.clone(),
+                            variant_index: variant,
+                        },
+                        source_info: src_info,
+                    };
+                    bb.statements.push(set_discriminant);
                 };
 
-                let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
-                    base: lhs_cast,
-                    elem: ProjectionElem::Field(Field::new(i), ty),
-                }));
-                let new_statement = Statement {
-                    source_info: src_info,
-                    kind: StatementKind::Assign(lhs_proj, rhs),
-                };
-                debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
-                bb.statements.push(new_statement);
+                curr = bb.statements.len();
+                bb.statements.extend(suffix_stmts);
             }
-
-            // if the aggregate was an enum, we need to set the discriminant
-            if adt_def.variants.len() > 1 {
-                let set_discriminant = Statement {
-                    kind: StatementKind::SetDiscriminant {
-                        lvalue: lhs.clone(),
-                        variant_index: variant,
-                    },
-                    source_info: src_info,
-                };
-                bb.statements.push(set_discriminant);
-            };
-
-            curr = bb.statements.len();
-            bb.statements.extend(suffix_stmts);
         }
     }
 }
diff --git a/src/test/mir-opt/deaggregator_test_enum_2.rs b/src/test/mir-opt/deaggregator_test_enum_2.rs
new file mode 100644
index 00000000000..02d496b2901
--- /dev/null
+++ b/src/test/mir-opt/deaggregator_test_enum_2.rs
@@ -0,0 +1,57 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Test that deaggregate fires in more than one basic block
+
+enum Foo {
+    A(i32),
+    B(i32),
+}
+
+fn test1(x: bool, y: i32) -> Foo {
+    if x {
+        Foo::A(y)
+    } else {
+        Foo::B(y)
+    }
+}
+
+fn main() {}
+
+// END RUST SOURCE
+// START rustc.node12.Deaggregator.before.mir
+//  bb1: {
+//      _6 = _4;
+//      _0 = Foo::A(_6,);
+//      goto -> bb3;
+//  }
+//
+//  bb2: {
+//      _7 = _4;
+//      _0 = Foo::B(_7,);
+//      goto -> bb3;
+//  }
+// END rustc.node12.Deaggregator.before.mir
+// START rustc.node12.Deaggregator.after.mir
+//  bb1: {
+//      _6 = _4;
+//      ((_0 as A).0: i32) = _6;
+//      discriminant(_0) = 0;
+//      goto -> bb3;
+//  }
+//
+//  bb2: {
+//      _7 = _4;
+//      ((_0 as B).0: i32) = _7;
+//      discriminant(_0) = 1;
+//      goto -> bb3;
+//  }
+// END rustc.node12.Deaggregator.after.mir
+//
diff --git a/src/test/mir-opt/deaggregator_test_multiple.rs b/src/test/mir-opt/deaggregator_test_multiple.rs
new file mode 100644
index 00000000000..a180a69be55
--- /dev/null
+++ b/src/test/mir-opt/deaggregator_test_multiple.rs
@@ -0,0 +1,48 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Test that deaggregate fires more than once per block
+
+enum Foo {
+    A(i32),
+    B,
+}
+
+fn test(x: i32) -> [Foo; 2] {
+    [Foo::A(x), Foo::A(x)]
+}
+
+fn main() { }
+
+// END RUST SOURCE
+// START rustc.node10.Deaggregator.before.mir
+// bb0: {
+//     _2 = _1;
+//     _4 = _2;
+//     _3 = Foo::A(_4,);
+//     _6 = _2;
+//     _5 = Foo::A(_6,);
+//     _0 = [_3, _5];
+//     return;
+// }
+// END rustc.node10.Deaggregator.before.mir
+// START rustc.node10.Deaggregator.after.mir
+// bb0: {
+//     _2 = _1;
+//     _4 = _2;
+//     ((_3 as A).0: i32) = _4;
+//     discriminant(_3) = 0;
+//     _6 = _2;
+//     ((_5 as A).0: i32) = _6;
+//     discriminant(_5) = 0;
+//     _0 = [_3, _5];
+//     return;
+// }
+// END rustc.node10.Deaggregator.after.mir