diff --git a/compiler/rustc_const_eval/src/const_eval/mod.rs b/compiler/rustc_const_eval/src/const_eval/mod.rs index 655c89cac22..ef31155215a 100644 --- a/compiler/rustc_const_eval/src/const_eval/mod.rs +++ b/compiler/rustc_const_eval/src/const_eval/mod.rs @@ -101,7 +101,7 @@ pub(crate) fn try_destructure_mir_constant_for_diagnostics<'tcx>( return None; } ty::Adt(def, _) => { - let variant = ecx.read_discriminant(&op).ok()?.1; + let variant = ecx.read_discriminant(&op).ok()?; let down = ecx.project_downcast(&op, variant).ok()?; (def.variants()[variant].fields.len(), Some(variant), down) } diff --git a/compiler/rustc_const_eval/src/const_eval/valtrees.rs b/compiler/rustc_const_eval/src/const_eval/valtrees.rs index 5f65cd6ec25..b519bcdf4a3 100644 --- a/compiler/rustc_const_eval/src/const_eval/valtrees.rs +++ b/compiler/rustc_const_eval/src/const_eval/valtrees.rs @@ -130,7 +130,7 @@ pub(crate) fn const_to_valtree_inner<'tcx>( bug!("uninhabited types should have errored and never gotten converted to valtree") } - let Ok((_, variant)) = ecx.read_discriminant(&place.into()) else { + let Ok(variant) = ecx.read_discriminant(&place.into()) else { return Err(ValTreeCreationError::Other); }; branches(ecx, place, def.variant(variant).fields.len(), def.is_enum().then_some(variant), num_nodes) diff --git a/compiler/rustc_const_eval/src/interpret/discriminant.rs b/compiler/rustc_const_eval/src/interpret/discriminant.rs index 9ea5e7cb1f9..aff86d5f486 100644 --- a/compiler/rustc_const_eval/src/interpret/discriminant.rs +++ b/compiler/rustc_const_eval/src/interpret/discriminant.rs @@ -1,6 +1,6 @@ //! Functions for reading and writing discriminants of multi-variant layouts (enums and generators). -use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt}; +use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt, TyAndLayout}; use rustc_middle::{mir, ty}; use rustc_target::abi::{self, TagEncoding}; use rustc_target::abi::{VariantIdx, Variants}; @@ -93,7 +93,7 @@ pub fn write_discriminant( pub fn read_discriminant( &self, op: &OpTy<'tcx, M::Provenance>, - ) -> InterpResult<'tcx, (Scalar, VariantIdx)> { + ) -> InterpResult<'tcx, VariantIdx> { trace!("read_discriminant_value {:#?}", op.layout); // Get type and layout of the discriminant. let discr_layout = self.layout_of(op.layout.ty.discriminant_ty(*self.tcx))?; @@ -106,30 +106,22 @@ pub fn read_discriminant( // straight-forward (`TagEncoding::Direct`) or with a niche (`TagEncoding::Niche`). let (tag_scalar_layout, tag_encoding, tag_field) = match op.layout.variants { Variants::Single { index } => { - // Hilariously, `Single` is used even for 0-variant enums. - // (See https://github.com/rust-lang/rust/issues/89765). - if matches!(op.layout.ty.kind(), ty::Adt(def, ..) if def.variants().is_empty()) { - throw_ub!(UninhabitedEnumVariantRead(index)) - } - let discr = match op.layout.ty.discriminant_for_variant(*self.tcx, index) { - Some(discr) => { - // This type actually has discriminants. - assert_eq!(discr.ty, discr_layout.ty); - Scalar::from_uint(discr.val, discr_layout.size) + // Do some extra checks on enums. + if op.layout.ty.is_enum() { + // Hilariously, `Single` is used even for 0-variant enums. + // (See https://github.com/rust-lang/rust/issues/89765). + if matches!(op.layout.ty.kind(), ty::Adt(def, ..) if def.variants().is_empty()) + { + throw_ub!(UninhabitedEnumVariantRead(index)) } - None => { - // On a type without actual discriminants, variant is 0. - assert_eq!(index.as_u32(), 0); - Scalar::from_uint(index.as_u32(), discr_layout.size) + // For consisteny with `write_discriminant`, and to make sure that + // `project_downcast` cannot fail due to strange layouts, we declare immediate UB + // for uninhabited variants. + if op.layout.for_variant(self, index).abi.is_uninhabited() { + throw_ub!(UninhabitedEnumVariantRead(index)) } - }; - // For consisteny with `write_discriminant`, and to make sure that - // `project_downcast` cannot fail due to strange layouts, we declare immediate UB - // for uninhabited variants. - if op.layout.ty.is_enum() && op.layout.for_variant(self, index).abi.is_uninhabited() { - throw_ub!(UninhabitedEnumVariantRead(index)) } - return Ok((discr, index)); + return Ok(index); } Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => { (tag, tag_encoding, tag_field) @@ -155,7 +147,7 @@ pub fn read_discriminant( trace!("tag value: {}", tag_val); // Figure out which discriminant and variant this corresponds to. - let (discr, index) = match *tag_encoding { + let index = match *tag_encoding { TagEncoding::Direct => { let scalar = tag_val.to_scalar(); // Generate a specific error if `tag_val` is not an integer. @@ -183,7 +175,7 @@ pub fn read_discriminant( } .ok_or_else(|| err_ub!(InvalidTag(Scalar::from_uint(tag_bits, tag_layout.size))))?; // Return the cast value, and the index. - (discr_val, index.0) + index.0 } TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => { let tag_val = tag_val.to_scalar(); @@ -241,13 +233,33 @@ pub fn read_discriminant( // Compute the size of the scalar we need to return. // No need to cast, because the variant index directly serves as discriminant and is // encoded in the tag. - (Scalar::from_uint(variant.as_u32(), discr_layout.size), variant) + variant } }; // For consisteny with `write_discriminant`, and to make sure that `project_downcast` cannot fail due to strange layouts, we declare immediate UB for uninhabited variants. if op.layout.for_variant(self, index).abi.is_uninhabited() { throw_ub!(UninhabitedEnumVariantRead(index)) } - Ok((discr, index)) + Ok(index) + } + + pub fn discriminant_for_variant( + &self, + layout: TyAndLayout<'tcx>, + variant: VariantIdx, + ) -> InterpResult<'tcx, Scalar> { + let discr_layout = self.layout_of(layout.ty.discriminant_ty(*self.tcx))?; + Ok(match layout.ty.discriminant_for_variant(*self.tcx, variant) { + Some(discr) => { + // This type actually has discriminants. + assert_eq!(discr.ty, discr_layout.ty); + Scalar::from_uint(discr.val, discr_layout.size) + } + None => { + // On a type without actual discriminants, variant is 0. + assert_eq!(variant.as_u32(), 0); + Scalar::from_uint(variant.as_u32(), discr_layout.size) + } + }) } } diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 5f981c9b918..3f697168280 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -226,8 +226,9 @@ pub fn emulate_intrinsic( } sym::discriminant_value => { let place = self.deref_operand(&args[0])?; - let discr_val = self.read_discriminant(&place.into())?.0; - self.write_scalar(discr_val, dest)?; + let variant = self.read_discriminant(&place.into())?; + let discr = self.discriminant_for_variant(place.layout, variant)?; + self.write_scalar(discr, dest)?; } sym::exact_div => { let l = self.read_immediate(&args[0])?; diff --git a/compiler/rustc_const_eval/src/interpret/step.rs b/compiler/rustc_const_eval/src/interpret/step.rs index 319c422134c..9182d23128f 100644 --- a/compiler/rustc_const_eval/src/interpret/step.rs +++ b/compiler/rustc_const_eval/src/interpret/step.rs @@ -302,8 +302,9 @@ pub fn eval_rvalue_into_place( Discriminant(place) => { let op = self.eval_place_to_op(place, None)?; - let discr_val = self.read_discriminant(&op)?.0; - self.write_scalar(discr_val, &dest)?; + let variant = self.read_discriminant(&op)?; + let discr = self.discriminant_for_variant(op.layout, variant)?; + self.write_scalar(discr, &dest)?; } } diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index 6618c70ac75..a82c98e7205 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -657,7 +657,7 @@ fn read_discriminant( ) -> InterpResult<'tcx, VariantIdx> { self.with_elem(PathElem::EnumTag, move |this| { Ok(try_validation!( - this.ecx.read_discriminant(op).map(|(_, idx)| idx), + this.ecx.read_discriminant(op), this.path, InvalidTag(val) => InvalidEnumTag { value: format!("{val:x}"), diff --git a/compiler/rustc_const_eval/src/interpret/visitor.rs b/compiler/rustc_const_eval/src/interpret/visitor.rs index a50233fa3de..4ec19d9e655 100644 --- a/compiler/rustc_const_eval/src/interpret/visitor.rs +++ b/compiler/rustc_const_eval/src/interpret/visitor.rs @@ -23,7 +23,7 @@ pub trait ValueVisitor<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>>: Sized { /// `read_discriminant` can be hooked for better error messages. #[inline(always)] fn read_discriminant(&mut self, v: &Self::V) -> InterpResult<'tcx, VariantIdx> { - Ok(self.ecx().read_discriminant(&v.to_op(self.ecx())?)?.1) + Ok(self.ecx().read_discriminant(&v.to_op(self.ecx())?)?) } /// This function provides the chance to reorder the order in which fields are visited for