rust/src/librustc_mir/transform/add_validation.rs

391 lines
18 KiB
Rust
Raw Normal View History

// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! This pass adds validation calls (AcquireValid, ReleaseValid) where appropriate.
//! It has to be run really early, before transformations like inlining, because
//! introducing these calls *adds* UB -- so, conceptually, this pass is actually part
//! of MIR building, and only after this pass we think of the program has having the
//! normal MIR semantics.
use rustc::ty::{self, TyCtxt, RegionKind};
use rustc::hir;
use rustc::mir::*;
use rustc::mir::transform::{MirPass, MirSource};
use rustc::middle::region;
pub struct AddValidation;
/// Determine the "context" of the lval: Mutability and region.
fn lval_context<'a, 'tcx, D>(
lval: &Lvalue<'tcx>,
local_decls: &D,
tcx: TyCtxt<'a, 'tcx, 'tcx>
) -> (Option<region::Scope>, hir::Mutability)
where D: HasLocalDecls<'tcx>
{
use rustc::mir::Lvalue::*;
match *lval {
Local { .. } => (None, hir::MutMutable),
Static(_) => (None, hir::MutImmutable),
Projection(ref proj) => {
match proj.elem {
ProjectionElem::Deref => {
// Computing the inside the recursion makes this quadratic.
// We don't expect deep paths though.
let ty = proj.base.ty(local_decls, tcx).to_ty(tcx);
// A Deref projection may restrict the context, this depends on the type
// being deref'd.
let context = match ty.sty {
ty::TyRef(re, tam) => {
let re = match re {
&RegionKind::ReScope(ce) => Some(ce),
&RegionKind::ReErased =>
bug!("AddValidation pass must be run before erasing lifetimes"),
_ => None
};
(re, tam.mutbl)
}
ty::TyRawPtr(_) =>
// There is no guarantee behind even a mutable raw pointer,
// no write locks are acquired there, so we also don't want to
// release any.
(None, hir::MutImmutable),
ty::TyAdt(adt, _) if adt.is_box() => (None, hir::MutMutable),
_ => bug!("Deref on a non-pointer type {:?}", ty),
};
// "Intersect" this restriction with proj.base.
if let (Some(_), hir::MutImmutable) = context {
// This is already as restricted as it gets, no need to even recurse
context
} else {
let base_context = lval_context(&proj.base, local_decls, tcx);
// The region of the outermost Deref is always most restrictive.
let re = context.0.or(base_context.0);
let mutbl = context.1.and(base_context.1);
(re, mutbl)
}
}
_ => lval_context(&proj.base, local_decls, tcx),
}
}
}
}
/// Check if this function contains an unsafe block or is an unsafe function.
fn fn_contains_unsafe<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, src: MirSource) -> bool {
use rustc::hir::intravisit::{self, Visitor, FnKind};
use rustc::hir::map::blocks::FnLikeNode;
2017-07-31 18:33:45 -07:00
use rustc::hir::map::Node;
/// Decide if this is an unsafe block
2017-07-31 19:51:10 -07:00
fn block_is_unsafe(block: &hir::Block) -> bool {
use rustc::hir::BlockCheckMode::*;
match block.rules {
UnsafeBlock(_) | PushUnsafeBlock(_) => true,
// For PopUnsafeBlock, we don't actually know -- but we will always also check all
// parent blocks, so we can safely declare the PopUnsafeBlock to not be unsafe.
DefaultBlock | PopUnsafeBlock(_) => false,
}
}
/// Decide if this FnLike is a closure
fn fn_is_closure<'a>(fn_like: FnLikeNode<'a>) -> bool {
match fn_like.kind() {
FnKind::Closure(_) => true,
FnKind::Method(..) | FnKind::ItemFn(..) => false,
}
}
let fn_like = match src {
MirSource::Fn(node_id) => {
match FnLikeNode::from_node(tcx.hir.get(node_id)) {
Some(fn_like) => fn_like,
None => return false, // e.g. struct ctor shims -- such auto-generated code cannot
// contain unsafe.
}
},
_ => return false, // only functions can have unsafe
};
// Test if the function is marked unsafe.
if fn_like.unsafety() == hir::Unsafety::Unsafe {
return true;
}
// For closures, we need to walk up the parents and see if we are inside an unsafe fn or
// unsafe block.
if fn_is_closure(fn_like) {
let mut cur = fn_like.id();
loop {
// Go further upwards.
cur = tcx.hir.get_parent_node(cur);
let node = tcx.hir.get(cur);
// Check if this is an unsafe function
if let Some(fn_like) = FnLikeNode::from_node(node) {
if !fn_is_closure(fn_like) {
if fn_like.unsafety() == hir::Unsafety::Unsafe {
return true;
}
2017-07-31 19:51:10 -07:00
}
}
// Check if this is an unsafe block, or an item
match node {
Node::NodeExpr(&hir::Expr { node: hir::ExprBlock(ref block), ..}) => {
if block_is_unsafe(&*block) {
// Found an unsafe block, we can bail out here.
return true;
2017-07-31 19:51:10 -07:00
}
}
Node::NodeItem(..) => {
// No walking up beyond items. This makes sure the loop always terminates.
break;
2017-07-31 19:51:10 -07:00
}
_ => {},
2017-07-31 19:51:10 -07:00
}
}
}
// Visit the entire body of the function and check for unsafe blocks in there
struct FindUnsafe {
found_unsafe: bool,
}
let mut finder = FindUnsafe { found_unsafe: false };
// Run the visitor on the NodeId we got. Seems like there is no uniform way to do that.
finder.visit_body(tcx.hir.body(fn_like.body()));
impl<'tcx> Visitor<'tcx> for FindUnsafe {
fn nested_visit_map<'this>(&'this mut self) -> intravisit::NestedVisitorMap<'this, 'tcx> {
intravisit::NestedVisitorMap::None
}
fn visit_block(&mut self, b: &'tcx hir::Block) {
if self.found_unsafe { return; } // short-circuit
2017-07-31 19:51:10 -07:00
if block_is_unsafe(b) {
// We found an unsafe block. We can stop searching.
self.found_unsafe = true;
} else {
// No unsafe block here, go on searching.
intravisit::walk_block(self, b);
}
}
}
finder.found_unsafe
}
impl MirPass for AddValidation {
fn run_pass<'a, 'tcx>(&self,
tcx: TyCtxt<'a, 'tcx, 'tcx>,
src: MirSource,
mir: &mut Mir<'tcx>)
{
let emit_validate = tcx.sess.opts.debugging_opts.mir_emit_validate;
if emit_validate == 0 {
2017-07-21 23:18:34 -07:00
return;
}
let restricted_validation = emit_validate == 1 && fn_contains_unsafe(tcx, src);
2017-07-22 01:04:16 -07:00
let local_decls = mir.local_decls.clone(); // FIXME: Find a way to get rid of this clone.
// Convert an lvalue to a validation operand.
let lval_to_operand = |lval: Lvalue<'tcx>| -> ValidationOperand<'tcx, Lvalue<'tcx>> {
let (re, mutbl) = lval_context(&lval, &local_decls, tcx);
let ty = lval.ty(&local_decls, tcx).to_ty(tcx);
ValidationOperand { lval, ty, re, mutbl }
};
2017-07-31 16:15:37 -07:00
// Emit an Acquire at the beginning of the given block. If we are in restricted emission
// mode (mir_emit_validate=1), also emit a Release immediately after the Acquire.
let emit_acquire = |block: &mut BasicBlockData<'tcx>, source_info, operands: Vec<_>| {
if operands.len() == 0 {
return; // Nothing to do
}
// Emit the release first, to avoid cloning if we do not emit it
if restricted_validation {
let release_stmt = Statement {
source_info,
kind: StatementKind::Validate(ValidationOp::Release, operands.clone()),
};
block.statements.insert(0, release_stmt);
}
// Now, the acquire
let acquire_stmt = Statement {
source_info,
kind: StatementKind::Validate(ValidationOp::Acquire, operands),
};
block.statements.insert(0, acquire_stmt);
};
// PART 1
// Add an AcquireValid at the beginning of the start block.
{
let source_info = SourceInfo {
scope: ARGUMENT_VISIBILITY_SCOPE,
span: mir.span, // FIXME: Consider using just the span covering the function
// argument declaration.
};
// Gather all arguments, skip return value.
let operands = mir.local_decls.iter_enumerated().skip(1).take(mir.arg_count)
.map(|(local, _)| lval_to_operand(Lvalue::Local(local))).collect();
emit_acquire(&mut mir.basic_blocks_mut()[START_BLOCK], source_info, operands);
}
// PART 2
// Add ReleaseValid/AcquireValid around function call terminators. We don't use a visitor
// because we need to access the block that a Call jumps to.
let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new();
for block_data in mir.basic_blocks_mut() {
match block_data.terminator {
Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. },
source_info }) => {
// Before the call: Release all arguments *and* the return value.
// The callee may write into the return value! Note that this relies
// on "release of uninitialized" to be a NOP.
if !restricted_validation {
let release_stmt = Statement {
source_info,
kind: StatementKind::Validate(ValidationOp::Release,
destination.iter().map(|dest| lval_to_operand(dest.0.clone()))
.chain(
args.iter().filter_map(|op| {
match op {
&Operand::Consume(ref lval) =>
Some(lval_to_operand(lval.clone())),
&Operand::Constant(..) => { None },
}
})
).collect())
};
block_data.statements.push(release_stmt);
}
// Remember the return destination for later
if let &Some(ref destination) = destination {
returns.push((source_info, destination.0.clone(), destination.1));
}
}
Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. },
source_info }) |
Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. },
source_info }) => {
2017-07-13 22:10:14 -07:00
// Before the call: Release all arguments
if !restricted_validation {
let release_stmt = Statement {
source_info,
kind: StatementKind::Validate(ValidationOp::Release,
vec![lval_to_operand(lval.clone())]),
};
block_data.statements.push(release_stmt);
}
2017-07-13 22:10:14 -07:00
// drop doesn't return anything, so we need no acquire.
}
_ => {
// Not a block ending in a Call -> ignore.
}
}
}
// Now we go over the returns we collected to acquire the return values.
for (source_info, dest_lval, dest_block) in returns {
emit_acquire(
&mut mir.basic_blocks_mut()[dest_block],
source_info,
vec![lval_to_operand(dest_lval)]
);
}
if restricted_validation {
// No part 3 for us.
return;
}
// PART 3
// Add ReleaseValid/AcquireValid around Ref and Cast. Again an iterator does not seem very
// suited as we need to add new statements before and after each Ref.
for block_data in mir.basic_blocks_mut() {
// We want to insert statements around Ref commands as we iterate. To this end, we
// iterate backwards using indices.
for i in (0..block_data.statements.len()).rev() {
match block_data.statements[i].kind {
// When the borrow of this ref expires, we need to recover validation.
StatementKind::Assign(_, Rvalue::Ref(_, _, _)) => {
// Due to a lack of NLL; we can't capture anything directly here.
// Instead, we have to re-match and clone there.
let (dest_lval, re, src_lval) = match block_data.statements[i].kind {
StatementKind::Assign(ref dest_lval,
Rvalue::Ref(re, _, ref src_lval)) => {
(dest_lval.clone(), re, src_lval.clone())
},
_ => bug!("We already matched this."),
};
// So this is a ref, and we got all the data we wanted.
// Do an acquire of the result -- but only what it points to, so add a Deref
// projection.
let dest_lval = Projection { base: dest_lval, elem: ProjectionElem::Deref };
let dest_lval = Lvalue::Projection(Box::new(dest_lval));
let acquire_stmt = Statement {
source_info: block_data.statements[i].source_info,
kind: StatementKind::Validate(ValidationOp::Acquire,
vec![lval_to_operand(dest_lval)]),
};
block_data.statements.insert(i+1, acquire_stmt);
// The source is released until the region of the borrow ends.
let op = match re {
&RegionKind::ReScope(ce) => ValidationOp::Suspend(ce),
&RegionKind::ReErased =>
bug!("AddValidation pass must be run before erasing lifetimes"),
_ => ValidationOp::Release,
};
let release_stmt = Statement {
source_info: block_data.statements[i].source_info,
kind: StatementKind::Validate(op, vec![lval_to_operand(src_lval)]),
};
block_data.statements.insert(i, release_stmt);
}
// Casts can change what validation does (e.g. unsizing)
StatementKind::Assign(_, Rvalue::Cast(kind, Operand::Consume(_), _))
if kind != CastKind::Misc =>
{
// Due to a lack of NLL; we can't capture anything directly here.
// Instead, we have to re-match and clone there.
let (dest_lval, src_lval) = match block_data.statements[i].kind {
StatementKind::Assign(ref dest_lval,
Rvalue::Cast(_, Operand::Consume(ref src_lval), _)) =>
{
(dest_lval.clone(), src_lval.clone())
},
_ => bug!("We already matched this."),
};
// Acquire of the result
let acquire_stmt = Statement {
source_info: block_data.statements[i].source_info,
kind: StatementKind::Validate(ValidationOp::Acquire,
vec![lval_to_operand(dest_lval)]),
};
block_data.statements.insert(i+1, acquire_stmt);
// Release of the input
let release_stmt = Statement {
source_info: block_data.statements[i].source_info,
kind: StatementKind::Validate(ValidationOp::Release,
vec![lval_to_operand(src_lval)]),
};
block_data.statements.insert(i, release_stmt);
}
_ => {},
}
}
}
}
}