From 3e92f90952a2f9afc64079dc0f2dd02dd6484388 Mon Sep 17 00:00:00 2001 From: Marijn Haverbeke Date: Tue, 13 Sep 2011 12:14:30 +0200 Subject: [PATCH] Apply implicit copying for unsafe references to alt patterns --- src/comp/middle/alias.rs | 140 ++++++++++++++++-------- src/comp/middle/mut.rs | 5 - src/comp/middle/trans_alt.rs | 23 +++- src/comp/middle/ty.rs | 11 ++ src/test/compile-fail/unsafe-alias-2.rs | 4 +- src/test/compile-fail/unsafe-alt.rs | 4 +- src/test/run-pass/alt-implicit-copy.rs | 6 + 7 files changed, 135 insertions(+), 58 deletions(-) create mode 100644 src/test/run-pass/alt-implicit-copy.rs diff --git a/src/comp/middle/alias.rs b/src/comp/middle/alias.rs index 47abdbbfe46..23da42291aa 100644 --- a/src/comp/middle/alias.rs +++ b/src/comp/middle/alias.rs @@ -1,7 +1,7 @@ import syntax::{ast, ast_util}; import ast::{ident, fn_ident, node_id, def_id}; -import mut::{expr_root, mut_field, inner_mut}; +import mut::{expr_root, mut_field, deref, field, index, unbox}; import syntax::codemap::span; import syntax::visit; import visit::vt; @@ -21,7 +21,7 @@ type restrict = span: span, local_id: uint, bindings: [node_id], - unsafe_ty: option::t, + unsafe_tys: [ty::t], depends_on: [uint], mutable ok: valid, mutable given_up: bool}; @@ -192,21 +192,17 @@ fn check_call(cx: ctx, f: @ast::expr, args: [@ast::expr], sc: scope) -> } } let root_var = path_def_id(cx, root.ex); - let unsafe_t = - alt inner_mut(root.ds) { some(t) { some(t) } _ { none } }; - restricts += - [ - // FIXME kludge - @{root_var: root_var, - node_id: arg_t.mode == ast::by_mut_ref ? 0 : arg.id, - ty: arg_t.ty, - span: arg.span, - local_id: cx.next_local, - bindings: [arg.id], - unsafe_ty: unsafe_t, - depends_on: deps(sc, root_var), - mutable ok: valid, - mutable given_up: arg_t.mode == ast::by_move}]; + restricts += [@{root_var: root_var, + // FIXME kludge + node_id: arg_t.mode == ast::by_mut_ref ? 0 : arg.id, + ty: arg_t.ty, + span: arg.span, + local_id: cx.next_local, + bindings: [arg.id], + unsafe_tys: inner_mut(root.ds), + depends_on: deps(sc, root_var), + mutable ok: valid, + mutable given_up: arg_t.mode == ast::by_move}]; i += 1u; } let f_may_close = @@ -217,7 +213,7 @@ fn check_call(cx: ctx, f: @ast::expr, args: [@ast::expr], sc: scope) -> if f_may_close { let i = 0u; for r in restricts { - if !option::is_none(r.unsafe_ty) && cant_copy(cx, r) { + if vec::len(r.unsafe_tys) > 0u && cant_copy(cx, r) { cx.tcx.sess.span_err(f.span, #fmt["function may alias with argument \ %u, which is not immutably rooted", @@ -228,8 +224,7 @@ fn check_call(cx: ctx, f: @ast::expr, args: [@ast::expr], sc: scope) -> } let j = 0u; for r in restricts { - alt r.unsafe_ty { - some(ty) { + for ty in r.unsafe_tys { let i = 0u; for arg_t: ty::arg in arg_ts { let mut_alias = arg_t.mode == ast::by_mut_ref; @@ -244,8 +239,6 @@ fn check_call(cx: ctx, f: @ast::expr, args: [@ast::expr], sc: scope) -> } i += 1u; } - } - _ { } } j += 1u; } @@ -279,24 +272,42 @@ fn check_alt(cx: ctx, input: @ast::expr, arms: [ast::arm], sc: scope, v.visit_expr(input, sc, v); let root = expr_root(cx.tcx, input, true); for a: ast::arm in arms { - let dnums = ast_util::pat_binding_ids(a.pats[0]); - let new_sc = sc; - if vec::len(dnums) > 0u { - let root_var = path_def_id(cx, root.ex); - // FIXME need to use separate restrict for each binding - new_sc = @(*sc + [@{root_var: root_var, - node_id: 0, - ty: ty::mk_int(cx.tcx), - span: a.pats[0].span, - local_id: cx.next_local, - bindings: dnums, - unsafe_ty: inner_mut(root.ds), - depends_on: deps(sc, root_var), - mutable ok: valid, - mutable given_up: false}]); + // FIXME handle other | patterns + let new_sc = *sc; + let root_var = path_def_id(cx, root.ex); + let pat_id_map = ast_util::pat_id_map(a.pats[0]); + type info = {id: node_id, mutable unsafe: [ty::t], span: span}; + let binding_info: [info] = []; + for pat in a.pats { + for proot in *pattern_roots(cx.tcx, root.ds, pat) { + let canon_id = pat_id_map.get(proot.name); + // FIXME I wanted to use a block, but that hit a + // typestate bug. + fn match(x: info, canon: node_id) -> bool { x.id == canon } + alt vec::find(bind match(_, canon_id), binding_info) { + some(s) { s.unsafe += inner_mut(proot.ds); } + none. { + binding_info += [{id: canon_id, + mutable unsafe: inner_mut(proot.ds), + span: proot.span}]; + } + } + } + } + for info in binding_info { + new_sc += [@{root_var: root_var, + node_id: info.id, + ty: ty::node_id_to_type(cx.tcx, info.id), + span: info.span, + local_id: cx.next_local, + bindings: [info.id], + unsafe_tys: info.unsafe, + depends_on: deps(sc, root_var), + mutable ok: valid, + mutable given_up: false}]; } register_locals(cx, a.pats[0]); - visit::visit_arm(a, new_sc, v); + visit::visit_arm(a, @new_sc, v); } } @@ -323,7 +334,7 @@ fn check_for(cx: ctx, local: @ast::local, seq: @ast::expr, blk: ast::blk, let elt_t; alt ty::struct(cx.tcx, seq_t) { ty::ty_vec(mt) { - if mt.mut != ast::imm { unsafe = some(seq_t); } + if mt.mut != ast::imm { unsafe = [seq_t]; } elt_t = mt.ty; } ty::ty_str. { elt_t = ty::mk_mach(cx.tcx, ast::ty_u8); } @@ -337,7 +348,7 @@ fn check_for(cx: ctx, local: @ast::local, seq: @ast::expr, blk: ast::blk, span: local.node.pat.span, local_id: cx.next_local, bindings: ast_util::pat_binding_ids(local.node.pat), - unsafe_ty: unsafe, + unsafe_tys: unsafe, depends_on: deps(sc, root_var), mutable ok: valid, mutable given_up: false}; @@ -354,16 +365,12 @@ fn check_var(cx: ctx, ex: @ast::expr, p: ast::path, id: ast::node_id, alt cx.local_map.find(my_defnum) { some(local(id)) { id } _ { 0u } }; let var_t = ty::expr_ty(cx.tcx, ex); for r: restrict in *sc { - // excludes variables introduced since the alias was made if my_local_id < r.local_id { - alt r.unsafe_ty { - some(ty) { + for ty in r.unsafe_tys { if ty_can_unsafely_include(cx, ty, var_t, assign) { r.ok = val_taken(ex.span, p); } - } - _ { } } } else if vec::member(my_defnum, r.bindings) { test_scope(cx, sc, r, p); @@ -546,6 +553,49 @@ fn copy_is_expensive(tcx: ty::ctxt, ty: ty::t) -> bool { ret score_ty(tcx, ty) > 8u; } +type pattern_root = {id: node_id, name: ident, ds: @[deref], span: span}; + +fn pattern_roots(tcx: ty::ctxt, base: @[deref], pat: @ast::pat) + -> @[pattern_root] { + fn walk(tcx: ty::ctxt, base: [deref], pat: @ast::pat, + &set: [pattern_root]) { + alt pat.node { + ast::pat_wild. | ast::pat_lit(_) {} + ast::pat_bind(nm) { + set += [{id: pat.id, name: nm, ds: @base, span: pat.span}]; + } + ast::pat_tag(_, ps) | ast::pat_tup(ps) { + let base = base + [@{mut: false, kind: field, + outer_t: ty::node_id_to_type(tcx, pat.id)}]; + for p in ps { walk(tcx, base, p, set); } + } + ast::pat_rec(fs, _) { + let ty = ty::node_id_to_type(tcx, pat.id); + for f in fs { + let mut = ty::get_field(tcx, ty, f.ident).mt.mut != ast::imm; + let base = base + [@{mut: mut, kind: field, outer_t: ty}]; + walk(tcx, base, f.pat, set); + } + } + ast::pat_box(p) { + let ty = ty::node_id_to_type(tcx, pat.id); + let mut = alt ty::struct(tcx, ty) { + ty::ty_box(mt) { mt.mut != ast::imm } + }; + walk(tcx, base + [@{mut: mut, kind: unbox, outer_t: ty}], p, set); + } + } + } + let set = []; + walk(tcx, *base, pat, set); + ret @set; +} + +fn inner_mut(ds: @[deref]) -> [ty::t] { + for d: deref in *ds { if d.mut { ret [d.outer_t]; } } + ret []; +} + // Local Variables: // mode: rust // fill-column: 78; diff --git a/src/comp/middle/mut.rs b/src/comp/middle/mut.rs index 2aa3ddc66c8..253b402fb04 100644 --- a/src/comp/middle/mut.rs +++ b/src/comp/middle/mut.rs @@ -110,11 +110,6 @@ fn mut_field(ds: @[deref]) -> bool { ret false; } -fn inner_mut(ds: @[deref]) -> option::t { - for d: deref in *ds { if d.mut { ret some(d.outer_t); } } - ret none; -} - // Actual mut-checking pass type mut_map = std::map::hashmap; diff --git a/src/comp/middle/trans_alt.rs b/src/comp/middle/trans_alt.rs index cd15116964f..0b84f85054a 100644 --- a/src/comp/middle/trans_alt.rs +++ b/src/comp/middle/trans_alt.rs @@ -471,22 +471,37 @@ fn make_phi_bindings(bcx: @block_ctxt, map: [exit_node], ids: ast_util::pat_id_map) -> bool { let our_block = bcx.llbb as uint; let success = true; - for each item: @{key: ast::ident, val: ast::node_id} in ids.items() { + for each @{key: name, val: node_id} in ids.items() { let llbbs = []; let vals = []; for ex: exit_node in map { if ex.to as uint == our_block { - alt assoc(item.key, ex.bound) { + alt assoc(name, ex.bound) { some(val) { llbbs += [ex.from]; vals += [val]; } none. { } } } } if vec::len(vals) > 0u { - let phi = Phi(bcx, val_ty(vals[0]), vals, llbbs); - bcx.fcx.lllocals.insert(item.val, phi); + let local = Phi(bcx, val_ty(vals[0]), vals, llbbs); + bcx.fcx.lllocals.insert(node_id, local); } else { success = false; } } + if success { + // Copy references that the alias analysis considered unsafe + for each @{val: node_id, _} in ids.items() { + if bcx_ccx(bcx).copy_map.contains_key(node_id) { + let local = bcx.fcx.lllocals.get(node_id); + let e_ty = ty::node_id_to_type(bcx_tcx(bcx), node_id); + let {bcx: abcx, val: alloc} = trans::alloc_ty(bcx, e_ty); + bcx = trans::copy_val(abcx, trans::INIT, alloc, + load_if_immediate(abcx, local, e_ty), + e_ty); + add_clean(bcx, alloc, e_ty); + bcx.fcx.lllocals.insert(node_id, alloc); + } + } + } ret success; } diff --git a/src/comp/middle/ty.rs b/src/comp/middle/ty.rs index 80095bdd452..082644dbd72 100644 --- a/src/comp/middle/ty.rs +++ b/src/comp/middle/ty.rs @@ -45,6 +45,7 @@ export expr_ty_params_and_ty; export fold_ty; export field; export field_idx; +export get_field; export fm_general; export get_element_type; export hash_ty; @@ -1680,6 +1681,16 @@ fn field_idx(sess: session::session, sp: span, id: ast::ident, sess.span_fatal(sp, "unknown field '" + id + "' of record"); } +fn get_field(tcx: ctxt, rec_ty: t, id: ast::ident) -> field { + alt struct(tcx, rec_ty) { + ty_rec(fields) { + alt vec::find({|f| str::eq(f.ident, id) }, fields) { + some(f) { ret f; } + } + } + } +} + fn method_idx(sess: session::session, sp: span, id: ast::ident, meths: [method]) -> uint { let i: uint = 0u; diff --git a/src/test/compile-fail/unsafe-alias-2.rs b/src/test/compile-fail/unsafe-alias-2.rs index 9f28c6190ef..23daa8bba16 100644 --- a/src/test/compile-fail/unsafe-alias-2.rs +++ b/src/test/compile-fail/unsafe-alias-2.rs @@ -1,8 +1,8 @@ // error-pattern:invalidate reference x -fn whoknows(x: @mutable int) { *x = 10; } +fn whoknows(x: @mutable {mutable x: int}) { x.x = 10; } fn main() { - let box = @mutable 1; + let box = @mutable {mutable x: 1}; alt *box { x { whoknows(box); log_err x; } } } diff --git a/src/test/compile-fail/unsafe-alt.rs b/src/test/compile-fail/unsafe-alt.rs index b8314771561..864ba50298b 100644 --- a/src/test/compile-fail/unsafe-alt.rs +++ b/src/test/compile-fail/unsafe-alt.rs @@ -1,8 +1,8 @@ // error-pattern:invalidate reference i -tag foo { left(int); right(bool); } +tag foo { left({mutable x: int}); right(bool); } fn main() { - let x = left(10); + let x = left({mutable x: 10}); alt x { left(i) { x = right(false); log i; } _ { } } } diff --git a/src/test/run-pass/alt-implicit-copy.rs b/src/test/run-pass/alt-implicit-copy.rs new file mode 100644 index 00000000000..7bb148b1b91 --- /dev/null +++ b/src/test/run-pass/alt-implicit-copy.rs @@ -0,0 +1,6 @@ +fn main() { + let x = @{mutable a: @10, b: @20}; + alt x { + @{a, b} { assert *a == 10; (*x).a = @30; assert *a == 10; } + } +}