Add comments

Still need to make it so that it maps discriminants to variant indexes.
Maybe instead I can map the variant indexes to discriminants?
This commit is contained in:
kadmin 2021-05-28 04:17:00 +00:00
parent 18144b66e1
commit 96db5e9c7b

View File

@ -4,7 +4,7 @@ use rustc_data_structures::stable_map::FxHashMap;
use rustc_middle::mir::*; use rustc_middle::mir::*;
use rustc_middle::ty::{self, Const, List, Ty, TyCtxt}; use rustc_middle::ty::{self, Const, List, Ty, TyCtxt};
use rustc_span::def_id::DefId; use rustc_span::def_id::DefId;
use rustc_target::abi::{Size, Variants}; use rustc_target::abi::{Size, TagEncoding, Variants};
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large /// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepanc between them /// enough discrepanc between them
@ -31,17 +31,25 @@ impl<const D: u64> EnumSizeOpt<D> {
match variants { match variants {
Variants::Single { .. } => None, Variants::Single { .. } => None,
Variants::Multiple { variants, .. } if variants.len() <= 1 => None, Variants::Multiple { variants, .. } if variants.len() <= 1 => None,
Variants::Multiple { tag_encoding, .. }
if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
{
None
}
Variants::Multiple { variants, .. } => { Variants::Multiple { variants, .. } => {
let min = variants.iter().map(|v| v.size).min().unwrap(); let min = variants.iter().map(|v| v.size).min().unwrap();
let max = variants.iter().map(|v| v.size).max().unwrap(); let max = variants.iter().map(|v| v.size).max().unwrap();
if max.bytes() - min.bytes() < D { if max.bytes() - min.bytes() < D {
return None; return None;
} }
Some(( let mut discr_sizes = vec![Size::ZERO; adt_def.discriminants(tcx).count()];
layout.size, for (var_idx, layout) in variants.iter_enumerated() {
variants.len() as u64, let disc_idx =
variants.iter().map(|v| v.size).collect(), adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
)) assert_eq!(discr_sizes[disc_idx], Size::ZERO);
discr_sizes[disc_idx] = layout.size;
}
Some((layout.size, variants.len() as u64, discr_sizes))
} }
} }
} }
@ -49,7 +57,7 @@ impl<const D: u64> EnumSizeOpt<D> {
} }
} }
fn optim(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { fn optim(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let mut match_cache = FxHashMap::default(); let mut alloc_cache = FxHashMap::default();
let body_did = body.source.def_id(); let body_did = body.source.def_id();
let mut patch = MirPatch::new(body); let mut patch = MirPatch::new(body);
let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut(); let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut();
@ -61,39 +69,45 @@ impl<const D: u64> EnumSizeOpt<D> {
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) => { )) => {
let ty = lhs.ty(local_decls, tcx).ty; let ty = lhs.ty(local_decls, tcx).ty;
let source_info = st.source_info;
let span = source_info.span;
let (total_size, num_variants, sizes) = let (total_size, num_variants, sizes) =
if let Some((ts, nv, s)) = match_cache.get(ty) { if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
(*ts, *nv, s) (ts, nv, s)
} else if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
// FIXME(jknodt) use entry API.
match_cache.insert(ty, (ts, nv, s));
let (ts, nv, s) = match_cache.get(ty).unwrap();
(*ts, *nv, s)
} else { } else {
return None; return None;
}; };
let source_info = st.source_info; let alloc = if let Some(alloc) = alloc_cache.get(ty) {
let span = source_info.span; alloc
} else {
let mut data =
vec![0; std::mem::size_of::<usize>() * num_variants as usize];
data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
let alloc = interpret::Allocation::from_bytes(
data,
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
Mutability::Not,
);
let alloc = tcx.intern_const_alloc(alloc);
alloc_cache.insert(ty, alloc);
// FIXME(jknodt) use entry API
alloc_cache.get(ty).unwrap()
};
let tmp_ty = tcx.mk_ty(ty::Array( let tmp_ty = tcx.mk_ty(ty::Array(
tcx.types.usize, tcx.types.usize,
Const::from_usize(tcx, num_variants), Const::from_usize(tcx, num_variants),
)); ));
let new_local = patch.new_temp(tmp_ty, span); let size_array_local = patch.new_temp(tmp_ty, span);
let store_live = let store_live = Statement {
Statement { source_info, kind: StatementKind::StorageLive(new_local) }; source_info,
kind: StatementKind::StorageLive(size_array_local),
};
let place = Place { local: new_local, projection: List::empty() }; let place = Place { local: size_array_local, projection: List::empty() };
let mut data =
vec![0; std::mem::size_of::<usize>() * num_variants as usize];
data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
let alloc = interpret::Allocation::from_bytes(
data,
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
);
let alloc = tcx.intern_const_alloc(alloc);
let constant_vals = Constant { let constant_vals = Constant {
span, span,
user_ty: None, user_ty: None,
@ -134,9 +148,9 @@ impl<const D: u64> EnumSizeOpt<D> {
kind: StatementKind::Assign(box ( kind: StatementKind::Assign(box (
size_place, size_place,
Rvalue::Use(Operand::Copy(Place { Rvalue::Use(Operand::Copy(Place {
local: discr_place.local, local: size_array_local,
projection: tcx projection: tcx
.intern_place_elems(&[PlaceElem::Index(size_place.local)]), .intern_place_elems(&[PlaceElem::Index(discr_place.local)]),
})), })),
)), )),
}; };
@ -187,8 +201,10 @@ impl<const D: u64> EnumSizeOpt<D> {
}), }),
}; };
let store_dead = let store_dead = Statement {
Statement { source_info, kind: StatementKind::StorageDead(new_local) }; source_info,
kind: StatementKind::StorageDead(size_array_local),
};
let iter = std::array::IntoIter::new([ let iter = std::array::IntoIter::new([
store_live, store_live,
const_assign, const_assign,