Use an interpreter in jump threading.
This commit is contained in:
parent
25f8d01fd8
commit
be9668d398
@ -36,16 +36,21 @@
|
||||
//! cost by `MAX_COST`.
|
||||
|
||||
use rustc_arena::DroplessArena;
|
||||
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
|
||||
use rustc_data_structures::fx::FxHashSet;
|
||||
use rustc_index::bit_set::BitSet;
|
||||
use rustc_index::IndexVec;
|
||||
use rustc_middle::mir::interpret::Scalar;
|
||||
use rustc_middle::mir::visit::Visitor;
|
||||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
|
||||
use rustc_middle::ty::layout::LayoutOf;
|
||||
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
|
||||
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
|
||||
use rustc_span::DUMMY_SP;
|
||||
use rustc_target::abi::{TagEncoding, Variants};
|
||||
|
||||
use crate::cost_checker::CostChecker;
|
||||
use crate::dataflow_const_prop::DummyMachine;
|
||||
|
||||
pub struct JumpThreading;
|
||||
|
||||
@ -71,6 +76,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
let mut finder = TOFinder {
|
||||
tcx,
|
||||
param_env,
|
||||
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
|
||||
body,
|
||||
arena: &arena,
|
||||
map: &map,
|
||||
@ -88,7 +94,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
debug!(?discr, ?bb);
|
||||
|
||||
let discr_ty = discr.ty(body, tcx).ty;
|
||||
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
|
||||
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
|
||||
|
||||
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
|
||||
debug!(?discr);
|
||||
@ -142,6 +148,7 @@ struct ThreadingOpportunity {
|
||||
struct TOFinder<'tcx, 'a> {
|
||||
tcx: TyCtxt<'tcx>,
|
||||
param_env: ty::ParamEnv<'tcx>,
|
||||
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||
body: &'a Body<'tcx>,
|
||||
map: &'a Map,
|
||||
loop_headers: &'a BitSet<BasicBlock>,
|
||||
@ -329,11 +336,11 @@ fn mutated_statement(
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_operand(
|
||||
fn process_immediate(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs: PlaceIndex,
|
||||
rhs: &Operand<'tcx>,
|
||||
rhs: ImmTy<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) -> Option<!> {
|
||||
let register_opportunity = |c: Condition| {
|
||||
@ -341,13 +348,60 @@ fn process_operand(
|
||||
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
|
||||
};
|
||||
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
|
||||
conditions.iter_matches(int).for_each(register_opportunity);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_operand(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs: PlaceIndex,
|
||||
rhs: &Operand<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) -> Option<!> {
|
||||
match rhs {
|
||||
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||
Operand::Constant(constant) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let constant =
|
||||
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
|
||||
conditions.iter_matches(constant).for_each(register_opportunity);
|
||||
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
|
||||
self.map.for_each_projection_value(
|
||||
lhs,
|
||||
constant,
|
||||
&mut |elem, op| match elem {
|
||||
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
|
||||
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
|
||||
TrackElem::Discriminant => {
|
||||
let variant = self.ecx.read_discriminant(op).ok()?;
|
||||
let discr_value =
|
||||
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
|
||||
Some(discr_value.into())
|
||||
}
|
||||
TrackElem::DerefLen => {
|
||||
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
|
||||
let len_usize = op.len(&self.ecx).ok()?;
|
||||
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
|
||||
Some(ImmTy::from_uint(len_usize, layout).into())
|
||||
}
|
||||
},
|
||||
&mut |place, op| {
|
||||
if let Some(conditions) = state.try_get_idx(place, self.map)
|
||||
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
|
||||
&& let Some(imm) = imm.right()
|
||||
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
|
||||
{
|
||||
conditions.iter_matches(int).for_each(|c: Condition| {
|
||||
self.opportunities.push(ThreadingOpportunity {
|
||||
chain: vec![bb],
|
||||
target: c.target,
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
// Transfer the conditions on the copied rhs.
|
||||
Operand::Move(rhs) | Operand::Copy(rhs) => {
|
||||
@ -374,18 +428,6 @@ fn process_statement(
|
||||
// Below, `lhs` is the return value of `mutated_statement`,
|
||||
// the place to which `conditions` apply.
|
||||
|
||||
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
|
||||
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
|
||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
|
||||
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
|
||||
Some(Operand::const_from_scalar(
|
||||
self.tcx,
|
||||
discr.ty,
|
||||
scalar.into(),
|
||||
rustc_span::DUMMY_SP,
|
||||
))
|
||||
};
|
||||
|
||||
match &stmt.kind {
|
||||
// If we expect `discriminant(place) ?= A`,
|
||||
// we have an opportunity if `variant_index ?= A`.
|
||||
@ -395,7 +437,7 @@ fn process_statement(
|
||||
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
|
||||
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
|
||||
// nothing.
|
||||
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
|
||||
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
|
||||
let writes_discriminant = match enum_layout.variants {
|
||||
Variants::Single { index } => {
|
||||
assert_eq!(index, *variant_index);
|
||||
@ -408,8 +450,8 @@ fn process_statement(
|
||||
} => *variant_index != untagged_variant,
|
||||
};
|
||||
if writes_discriminant {
|
||||
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
|
||||
self.process_operand(bb, discr_target, &discr, state)?;
|
||||
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
|
||||
self.process_immediate(bb, discr_target, discr, state)?;
|
||||
}
|
||||
}
|
||||
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
|
||||
@ -440,10 +482,16 @@ fn process_statement(
|
||||
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
|
||||
if let Some(discr_target) =
|
||||
self.map.apply(lhs, TrackElem::Discriminant)
|
||||
&& let Some(discr_value) =
|
||||
discriminant_for_variant(agg_ty, *variant_index)
|
||||
&& let Ok(discr_value) = self
|
||||
.ecx
|
||||
.discriminant_for_variant(agg_ty, *variant_index)
|
||||
{
|
||||
self.process_operand(bb, discr_target, &discr_value, state);
|
||||
self.process_immediate(
|
||||
bb,
|
||||
discr_target,
|
||||
discr_value,
|
||||
state,
|
||||
);
|
||||
}
|
||||
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
|
||||
}
|
||||
@ -577,7 +625,7 @@ fn process_switch_int(
|
||||
|
||||
let discr = discr.place()?;
|
||||
let discr_ty = discr.ty(self.body, self.tcx).ty;
|
||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
|
||||
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
|
||||
let conditions = state.try_get(discr.as_ref(), self.map)?;
|
||||
|
||||
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
|
||||
|
@ -0,0 +1,52 @@
|
||||
- // MIR for `aggregate` before JumpThreading
|
||||
+ // MIR for `aggregate` after JumpThreading
|
||||
|
||||
fn aggregate(_1: u8) -> u8 {
|
||||
debug x => _1;
|
||||
let mut _0: u8;
|
||||
let _2: u8;
|
||||
let _3: u8;
|
||||
let mut _4: (u8, u8);
|
||||
let mut _5: bool;
|
||||
let mut _6: u8;
|
||||
scope 1 {
|
||||
debug a => _2;
|
||||
debug b => _3;
|
||||
}
|
||||
|
||||
bb0: {
|
||||
StorageLive(_4);
|
||||
_4 = const _;
|
||||
StorageLive(_2);
|
||||
_2 = (_4.0: u8);
|
||||
StorageLive(_3);
|
||||
_3 = (_4.1: u8);
|
||||
StorageDead(_4);
|
||||
StorageLive(_5);
|
||||
StorageLive(_6);
|
||||
_6 = _2;
|
||||
_5 = Eq(move _6, const 7_u8);
|
||||
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||
+ goto -> bb2;
|
||||
}
|
||||
|
||||
bb1: {
|
||||
StorageDead(_6);
|
||||
_0 = _3;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb2: {
|
||||
StorageDead(_6);
|
||||
_0 = _2;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb3: {
|
||||
StorageDead(_5);
|
||||
StorageDead(_3);
|
||||
StorageDead(_2);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,52 @@
|
||||
- // MIR for `aggregate` before JumpThreading
|
||||
+ // MIR for `aggregate` after JumpThreading
|
||||
|
||||
fn aggregate(_1: u8) -> u8 {
|
||||
debug x => _1;
|
||||
let mut _0: u8;
|
||||
let _2: u8;
|
||||
let _3: u8;
|
||||
let mut _4: (u8, u8);
|
||||
let mut _5: bool;
|
||||
let mut _6: u8;
|
||||
scope 1 {
|
||||
debug a => _2;
|
||||
debug b => _3;
|
||||
}
|
||||
|
||||
bb0: {
|
||||
StorageLive(_4);
|
||||
_4 = const _;
|
||||
StorageLive(_2);
|
||||
_2 = (_4.0: u8);
|
||||
StorageLive(_3);
|
||||
_3 = (_4.1: u8);
|
||||
StorageDead(_4);
|
||||
StorageLive(_5);
|
||||
StorageLive(_6);
|
||||
_6 = _2;
|
||||
_5 = Eq(move _6, const 7_u8);
|
||||
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||
+ goto -> bb2;
|
||||
}
|
||||
|
||||
bb1: {
|
||||
StorageDead(_6);
|
||||
_0 = _3;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb2: {
|
||||
StorageDead(_6);
|
||||
_0 = _2;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb3: {
|
||||
StorageDead(_5);
|
||||
StorageDead(_3);
|
||||
StorageDead(_2);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
|
||||
)
|
||||
}
|
||||
|
||||
/// Verify that we can thread jumps when we assign from an aggregate constant.
|
||||
fn aggregate(x: u8) -> u8 {
|
||||
// CHECK-LABEL: fn aggregate(
|
||||
// CHECK-NOT: switchInt(
|
||||
|
||||
const FOO: (u8, u8) = (5, 13);
|
||||
|
||||
let (a, b) = FOO;
|
||||
if a == 7 {
|
||||
b
|
||||
} else {
|
||||
a
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// CHECK-LABEL: fn main(
|
||||
too_complex(Ok(0));
|
||||
identity(Ok(0));
|
||||
custom_discr(false);
|
||||
@ -464,6 +480,7 @@ fn main() {
|
||||
mutable_ref();
|
||||
renumbered_bb(true);
|
||||
disappearing_bb(7);
|
||||
aggregate(7);
|
||||
}
|
||||
|
||||
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
|
||||
@ -476,3 +493,4 @@ fn main() {
|
||||
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff
|
||||
|
Loading…
Reference in New Issue
Block a user