From b4217b383bbbebafbdc621f31716d5bee3cd0c72 Mon Sep 17 00:00:00 2001
From: Marijn Haverbeke <marijnh@gmail.com>
Date: Mon, 21 Nov 2011 10:56:00 +0100
Subject: [PATCH] Add a pass that checks that blocks are only used in safe ways

Closes #1188
---
 src/comp/driver/rustc.rs                      |  2 +
 src/comp/middle/block_use.rs                  | 41 +++++++++++++++++++
 src/comp/rustc.rc                             |  1 +
 src/test/compile-fail/block-copy.rs           |  2 +-
 .../compile-fail/block-deinitializes-upvar.rs |  7 +---
 5 files changed, 47 insertions(+), 6 deletions(-)
 create mode 100644 src/comp/middle/block_use.rs

diff --git a/src/comp/driver/rustc.rs b/src/comp/driver/rustc.rs
index 32dd652b059..c976410554a 100644
--- a/src/comp/driver/rustc.rs
+++ b/src/comp/driver/rustc.rs
@@ -139,6 +139,8 @@ fn compile_input(sess: session::session, cfg: ast::crate_cfg, input: str,
              bind freevars::annotate_freevars(def_map, crate));
     let ty_cx = ty::mk_ctxt(sess, def_map, ext_map, ast_map, freevars);
     time(time_passes, "typechecking", bind typeck::check_crate(ty_cx, crate));
+    time(time_passes, "block-use checking",
+         bind middle::block_use::check_crate(ty_cx, crate));
     time(time_passes, "function usage",
          bind fn_usage::check_crate_fn_usage(ty_cx, crate));
     time(time_passes, "alt checking",
diff --git a/src/comp/middle/block_use.rs b/src/comp/middle/block_use.rs
new file mode 100644
index 00000000000..78a49692e01
--- /dev/null
+++ b/src/comp/middle/block_use.rs
@@ -0,0 +1,41 @@
+import syntax::visit;
+import syntax::ast::*;
+
+type ctx = {tcx: ty::ctxt, mutable allow_block: bool};
+
+fn check_crate(tcx: ty::ctxt, crate: @crate) {
+    let cx = {tcx: tcx, mutable allow_block: false};
+    let v = visit::mk_vt(@{visit_expr: visit_expr
+                           with *visit::default_visitor()});
+    visit::visit_crate(*crate, cx, v);
+}
+
+fn visit_expr(ex: @expr, cx: ctx, v: visit::vt<ctx>) {
+    if !cx.allow_block {
+        alt ty::struct(cx.tcx, ty::expr_ty(cx.tcx, ex)) {
+          ty::ty_fn(proto_block., _, _, _, _) {
+            cx.tcx.sess.span_err(ex.span, "expressions with block type \
+                can only appear in callee or (by-ref) argument position");
+          }
+          _ {}
+        }
+    }
+    let outer = cx.allow_block;
+    alt ex.node {
+      expr_call(f, args, _) {
+        cx.allow_block = true;
+        v.visit_expr(f, cx, v);
+        let i = 0u;
+        for arg_t in ty::ty_fn_args(cx.tcx, ty::expr_ty(cx.tcx, f)) {
+            cx.allow_block = arg_t.mode == by_ref;
+            v.visit_expr(args[i], cx, v);
+            i += 1u;
+        }
+      }
+      _ {
+        cx.allow_block = false;
+        visit::visit_expr(ex, cx, v);
+      }
+    }
+    cx.allow_block = outer;
+}
diff --git a/src/comp/rustc.rc b/src/comp/rustc.rc
index af6369b5bce..8f8a1414513 100644
--- a/src/comp/rustc.rc
+++ b/src/comp/rustc.rc
@@ -31,6 +31,7 @@ mod middle {
     mod mut;
     mod alias;
     mod last_use;
+    mod block_use;
     mod kind;
     mod freevars;
     mod shape;
diff --git a/src/test/compile-fail/block-copy.rs b/src/test/compile-fail/block-copy.rs
index b14bb3ddedd..313e86511d1 100644
--- a/src/test/compile-fail/block-copy.rs
+++ b/src/test/compile-fail/block-copy.rs
@@ -1,4 +1,4 @@
-// error-pattern: copying a noncopyable value
+// error-pattern: block type can only appear
 
 fn lol(f: block()) -> block() { ret f; }
 fn main() { let i = 8; let f = lol(block () { log_err i; }); f(); }
diff --git a/src/test/compile-fail/block-deinitializes-upvar.rs b/src/test/compile-fail/block-deinitializes-upvar.rs
index 72cd09b18ab..65367a01a7b 100644
--- a/src/test/compile-fail/block-deinitializes-upvar.rs
+++ b/src/test/compile-fail/block-deinitializes-upvar.rs
@@ -1,11 +1,8 @@
 // error-pattern:Tried to deinitialize a variable declared in a different
-fn force(f: block() -> int) -> int { ret f(); }
+fn force(f: block()) { f(); }
 fn main() {
     let x = @{x: 17, y: 2};
     let y = @{x: 5, y: 5};
 
-    let f = {|i| log_err i; x <- y; ret 7; };
-    assert (f(5) == 7);
-    log_err x;
-    log_err y;
+    force({|| x <- y;});
 }