diff --git a/compiler/rustc_mir/src/transform/large_enums.rs b/compiler/rustc_mir/src/transform/large_enums.rs index b742b7a45e6..a8377c95dcb 100644 --- a/compiler/rustc_mir/src/transform/large_enums.rs +++ b/compiler/rustc_mir/src/transform/large_enums.rs @@ -4,7 +4,7 @@ use rustc_data_structures::stable_map::FxHashMap; use rustc_middle::mir::*; use rustc_middle::ty::{self, Const, List, Ty, TyCtxt}; use rustc_span::def_id::DefId; -use rustc_target::abi::{Size, Variants}; +use rustc_target::abi::{Size, TagEncoding, Variants}; /// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large /// enough discrepanc between them @@ -31,17 +31,25 @@ impl EnumSizeOpt { match variants { Variants::Single { .. } => None, Variants::Multiple { variants, .. } if variants.len() <= 1 => None, + Variants::Multiple { tag_encoding, .. } + if matches!(tag_encoding, TagEncoding::Niche { .. }) => + { + None + } Variants::Multiple { variants, .. } => { let min = variants.iter().map(|v| v.size).min().unwrap(); let max = variants.iter().map(|v| v.size).max().unwrap(); if max.bytes() - min.bytes() < D { return None; } - Some(( - layout.size, - variants.len() as u64, - variants.iter().map(|v| v.size).collect(), - )) + let mut discr_sizes = vec![Size::ZERO; adt_def.discriminants(tcx).count()]; + for (var_idx, layout) in variants.iter_enumerated() { + let disc_idx = + adt_def.discriminant_for_variant(tcx, var_idx).val as usize; + assert_eq!(discr_sizes[disc_idx], Size::ZERO); + discr_sizes[disc_idx] = layout.size; + } + Some((layout.size, variants.len() as u64, discr_sizes)) } } } @@ -49,7 +57,7 @@ impl EnumSizeOpt { } } fn optim(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let mut match_cache = FxHashMap::default(); + let mut alloc_cache = FxHashMap::default(); let body_did = body.source.def_id(); let mut patch = MirPatch::new(body); let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut(); @@ -61,39 +69,45 @@ impl EnumSizeOpt { Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), )) => { let ty = lhs.ty(local_decls, tcx).ty; + let source_info = st.source_info; + let span = source_info.span; + let (total_size, num_variants, sizes) = - if let Some((ts, nv, s)) = match_cache.get(ty) { - (*ts, *nv, s) - } else if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) { - // FIXME(jknodt) use entry API. - match_cache.insert(ty, (ts, nv, s)); - let (ts, nv, s) = match_cache.get(ty).unwrap(); - (*ts, *nv, s) + if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) { + (ts, nv, s) } else { return None; }; - let source_info = st.source_info; - let span = source_info.span; + let alloc = if let Some(alloc) = alloc_cache.get(ty) { + alloc + } else { + let mut data = + vec![0; std::mem::size_of::() * num_variants as usize]; + data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) }); + let alloc = interpret::Allocation::from_bytes( + data, + tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, + Mutability::Not, + ); + let alloc = tcx.intern_const_alloc(alloc); + alloc_cache.insert(ty, alloc); + // FIXME(jknodt) use entry API + alloc_cache.get(ty).unwrap() + }; let tmp_ty = tcx.mk_ty(ty::Array( tcx.types.usize, Const::from_usize(tcx, num_variants), )); - let new_local = patch.new_temp(tmp_ty, span); - let store_live = - Statement { source_info, kind: StatementKind::StorageLive(new_local) }; + let size_array_local = patch.new_temp(tmp_ty, span); + let store_live = Statement { + source_info, + kind: StatementKind::StorageLive(size_array_local), + }; - let place = Place { local: new_local, projection: List::empty() }; - let mut data = - vec![0; std::mem::size_of::() * num_variants as usize]; - data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) }); - let alloc = interpret::Allocation::from_bytes( - data, - tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, - ); - let alloc = tcx.intern_const_alloc(alloc); + let place = Place { local: size_array_local, projection: List::empty() }; let constant_vals = Constant { span, user_ty: None, @@ -134,9 +148,9 @@ impl EnumSizeOpt { kind: StatementKind::Assign(box ( size_place, Rvalue::Use(Operand::Copy(Place { - local: discr_place.local, + local: size_array_local, projection: tcx - .intern_place_elems(&[PlaceElem::Index(size_place.local)]), + .intern_place_elems(&[PlaceElem::Index(discr_place.local)]), })), )), }; @@ -187,8 +201,10 @@ impl EnumSizeOpt { }), }; - let store_dead = - Statement { source_info, kind: StatementKind::StorageDead(new_local) }; + let store_dead = Statement { + source_info, + kind: StatementKind::StorageDead(size_array_local), + }; let iter = std::array::IntoIter::new([ store_live, const_assign,