Use an interpreter in jump threading.

This commit is contained in:
Camille GILLOT 2023-12-31 01:53:51 +00:00
parent 25f8d01fd8
commit be9668d398
4 changed files with 197 additions and 27 deletions

View File

@ -36,16 +36,21 @@
//! cost by `MAX_COST`.
use rustc_arena::DroplessArena;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants};
use crate::cost_checker::CostChecker;
use crate::dataflow_const_prop::DummyMachine;
pub struct JumpThreading;
@ -71,6 +76,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let mut finder = TOFinder {
tcx,
param_env,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
body,
arena: &arena,
map: &map,
@ -88,7 +94,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(?discr, ?bb);
let discr_ty = discr.ty(body, tcx).ty;
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
debug!(?discr);
@ -142,6 +148,7 @@ struct ThreadingOpportunity {
struct TOFinder<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
body: &'a Body<'tcx>,
map: &'a Map,
loop_headers: &'a BitSet<BasicBlock>,
@ -329,11 +336,11 @@ fn mutated_statement(
}
#[instrument(level = "trace", skip(self))]
fn process_operand(
fn process_immediate(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
rhs: ImmTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let register_opportunity = |c: Condition| {
@ -341,13 +348,60 @@ fn process_operand(
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
};
let conditions = state.try_get_idx(lhs, self.map)?;
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
conditions.iter_matches(int).for_each(register_opportunity);
}
None
}
#[instrument(level = "trace", skip(self))]
fn process_operand(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let constant =
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
conditions.iter_matches(constant).for_each(register_opportunity);
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
self.map.for_each_projection_value(
lhs,
constant,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let discr_value =
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Some(conditions) = state.try_get_idx(place, self.map)
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
{
conditions.iter_matches(int).for_each(|c: Condition| {
self.opportunities.push(ThreadingOpportunity {
chain: vec![bb],
target: c.target,
})
})
}
},
);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
@ -374,18 +428,6 @@ fn process_statement(
// Below, `lhs` is the return value of `mutated_statement`,
// the place to which `conditions` apply.
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
Some(Operand::const_from_scalar(
self.tcx,
discr.ty,
scalar.into(),
rustc_span::DUMMY_SP,
))
};
match &stmt.kind {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
@ -395,7 +437,7 @@ fn process_statement(
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
// nothing.
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
let writes_discriminant = match enum_layout.variants {
Variants::Single { index } => {
assert_eq!(index, *variant_index);
@ -408,8 +450,8 @@ fn process_statement(
} => *variant_index != untagged_variant,
};
if writes_discriminant {
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
self.process_operand(bb, discr_target, &discr, state)?;
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
self.process_immediate(bb, discr_target, discr, state)?;
}
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@ -440,10 +482,16 @@ fn process_statement(
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) =
self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) =
discriminant_for_variant(agg_ty, *variant_index)
&& let Ok(discr_value) = self
.ecx
.discriminant_for_variant(agg_ty, *variant_index)
{
self.process_operand(bb, discr_target, &discr_value, state);
self.process_immediate(
bb,
discr_target,
discr_value,
state,
);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
@ -577,7 +625,7 @@ fn process_switch_int(
let discr = discr.place()?;
let discr_ty = discr.ty(self.body, self.tcx).ty;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let conditions = state.try_get(discr.as_ref(), self.map)?;
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {

View File

@ -0,0 +1,52 @@
- // MIR for `aggregate` before JumpThreading
+ // MIR for `aggregate` after JumpThreading
fn aggregate(_1: u8) -> u8 {
debug x => _1;
let mut _0: u8;
let _2: u8;
let _3: u8;
let mut _4: (u8, u8);
let mut _5: bool;
let mut _6: u8;
scope 1 {
debug a => _2;
debug b => _3;
}
bb0: {
StorageLive(_4);
_4 = const _;
StorageLive(_2);
_2 = (_4.0: u8);
StorageLive(_3);
_3 = (_4.1: u8);
StorageDead(_4);
StorageLive(_5);
StorageLive(_6);
_6 = _2;
_5 = Eq(move _6, const 7_u8);
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
+ goto -> bb2;
}
bb1: {
StorageDead(_6);
_0 = _3;
goto -> bb3;
}
bb2: {
StorageDead(_6);
_0 = _2;
goto -> bb3;
}
bb3: {
StorageDead(_5);
StorageDead(_3);
StorageDead(_2);
return;
}
}

View File

@ -0,0 +1,52 @@
- // MIR for `aggregate` before JumpThreading
+ // MIR for `aggregate` after JumpThreading
fn aggregate(_1: u8) -> u8 {
debug x => _1;
let mut _0: u8;
let _2: u8;
let _3: u8;
let mut _4: (u8, u8);
let mut _5: bool;
let mut _6: u8;
scope 1 {
debug a => _2;
debug b => _3;
}
bb0: {
StorageLive(_4);
_4 = const _;
StorageLive(_2);
_2 = (_4.0: u8);
StorageLive(_3);
_3 = (_4.1: u8);
StorageDead(_4);
StorageLive(_5);
StorageLive(_6);
_6 = _2;
_5 = Eq(move _6, const 7_u8);
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
+ goto -> bb2;
}
bb1: {
StorageDead(_6);
_0 = _3;
goto -> bb3;
}
bb2: {
StorageDead(_6);
_0 = _2;
goto -> bb3;
}
bb3: {
StorageDead(_5);
StorageDead(_3);
StorageDead(_2);
return;
}
}

View File

@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
)
}
/// Verify that we can thread jumps when we assign from an aggregate constant.
fn aggregate(x: u8) -> u8 {
// CHECK-LABEL: fn aggregate(
// CHECK-NOT: switchInt(
const FOO: (u8, u8) = (5, 13);
let (a, b) = FOO;
if a == 7 {
b
} else {
a
}
}
fn main() {
// CHECK-LABEL: fn main(
too_complex(Ok(0));
identity(Ok(0));
custom_discr(false);
@ -464,6 +480,7 @@ fn main() {
mutable_ref();
renumbered_bb(true);
disappearing_bb(7);
aggregate(7);
}
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@ -476,3 +493,4 @@ fn main() {
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff