diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs index 8ddd3d48539..93d49038fc1 100644 --- a/compiler/rustc_smir/src/rustc_smir/context.rs +++ b/compiler/rustc_smir/src/rustc_smir/context.rs @@ -255,16 +255,11 @@ impl<'tcx> Context for TablesWrapper<'tcx> { tables.tcx.type_of(item.internal(&mut *tables)).instantiate_identity().stable(&mut *tables) } - fn def_ty_with_args( - &self, - item: stable_mir::DefId, - args: &GenericArgs, - ) -> Result { + fn def_ty_with_args(&self, item: stable_mir::DefId, args: &GenericArgs) -> stable_mir::ty::Ty { let mut tables = self.0.borrow_mut(); let args = args.internal(&mut *tables); let def_ty = tables.tcx.type_of(item.internal(&mut *tables)); - // FIXME(celinval): use try_fold instead to avoid crashing. - Ok(def_ty.instantiate(tables.tcx, args).stable(&mut *tables)) + def_ty.instantiate(tables.tcx, args).stable(&mut *tables) } fn const_literal(&self, cnst: &stable_mir::ty::Const) -> String { diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs index c3ee0a60f4d..9c0b2b29bca 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs @@ -75,24 +75,3 @@ impl<'tcx> Stable<'tcx> for rustc_span::Span { tables.create_span(*self) } } - -impl<'tcx, T> Stable<'tcx> for &[T] -where - T: Stable<'tcx>, -{ - type T = Vec; - fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { - self.iter().map(|e| e.stable(tables)).collect() - } -} - -impl<'tcx, T, U> Stable<'tcx> for (T, U) -where - T: Stable<'tcx>, - U: Stable<'tcx>, -{ - type T = (T::T, U::T); - fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { - (self.0.stable(tables), self.1.stable(tables)) - } -} diff --git a/compiler/rustc_smir/src/rustc_smir/mod.rs b/compiler/rustc_smir/src/rustc_smir/mod.rs index eee587f3b2a..4cb48a12c96 100644 --- a/compiler/rustc_smir/src/rustc_smir/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/mod.rs @@ -141,3 +141,24 @@ where } } } + +impl<'tcx, T> Stable<'tcx> for &[T] +where + T: Stable<'tcx>, +{ + type T = Vec; + fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { + self.iter().map(|e| e.stable(tables)).collect() + } +} + +impl<'tcx, T, U> Stable<'tcx> for (T, U) +where + T: Stable<'tcx>, + U: Stable<'tcx>, +{ + type T = (T::T, U::T); + fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { + (self.0.stable(tables), self.1.stable(tables)) + } +} diff --git a/compiler/stable_mir/src/compiler_interface.rs b/compiler/stable_mir/src/compiler_interface.rs index 8f0c5f73796..d8a3d4eda4b 100644 --- a/compiler/stable_mir/src/compiler_interface.rs +++ b/compiler/stable_mir/src/compiler_interface.rs @@ -91,7 +91,7 @@ pub trait Context { fn def_ty(&self, item: DefId) -> Ty; /// Returns the type of given definition instantiated with the given arguments. - fn def_ty_with_args(&self, item: DefId, args: &GenericArgs) -> Result; + fn def_ty_with_args(&self, item: DefId, args: &GenericArgs) -> Ty; /// Returns literal value of a const as a string. fn const_literal(&self, cnst: &Const) -> String; diff --git a/compiler/stable_mir/src/error.rs b/compiler/stable_mir/src/error.rs index 1ff65717e87..bb5e1a34180 100644 --- a/compiler/stable_mir/src/error.rs +++ b/compiler/stable_mir/src/error.rs @@ -28,7 +28,7 @@ pub enum CompilerError { } /// A generic error to represent an API request that cannot be fulfilled. -#[derive(Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Error(pub(crate) String); impl Error { diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs index 0c44781c463..d46caad9a01 100644 --- a/compiler/stable_mir/src/mir/visit.rs +++ b/compiler/stable_mir/src/mir/visit.rs @@ -452,7 +452,7 @@ impl Location { } /// Information about a place's usage. -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct PlaceContext { /// Whether the access is mutable or not. Keep this private so we can increment the type in a /// backward compatible manner. diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs index 2724b1fe0a6..f64b1f5f5a3 100644 --- a/compiler/stable_mir/src/ty.rs +++ b/compiler/stable_mir/src/ty.rs @@ -382,7 +382,9 @@ impl AdtDef { } /// Retrieve the type of this Adt instantiating the type with the given arguments. - pub fn ty_with_args(&self, args: &GenericArgs) -> Result { + /// + /// This will assume the type can be instantiated with these arguments. + pub fn ty_with_args(&self, args: &GenericArgs) -> Ty { with(|cx| cx.def_ty_with_args(self.0, args)) } @@ -441,6 +443,7 @@ impl VariantDef { } } +#[derive(Clone, Debug, Eq, PartialEq)] pub struct FieldDef { /// The field definition. /// @@ -454,7 +457,9 @@ pub struct FieldDef { impl FieldDef { /// Retrieve the type of this field instantiating the type with the given arguments. - pub fn ty_with_args(&self, args: &GenericArgs) -> Result { + /// + /// This will assume the type can be instantiated with these arguments. + pub fn ty_with_args(&self, args: &GenericArgs) -> Ty { with(|cx| cx.def_ty_with_args(self.def, args)) } diff --git a/tests/ui-fulldeps/stable-mir/check_ty_fold.rs b/tests/ui-fulldeps/stable-mir/check_ty_fold.rs new file mode 100644 index 00000000000..b90d47d4540 --- /dev/null +++ b/tests/ui-fulldeps/stable-mir/check_ty_fold.rs @@ -0,0 +1,115 @@ +// run-pass +//! Test that users are able to use stable mir APIs to retrieve monomorphized types, and that +//! we have an error handling for trying to instantiate types with incorrect arguments. + +// ignore-stage1 +// ignore-cross-compile +// ignore-remote +// ignore-windows-gnu mingw has troubles with linking https://github.com/rust-lang/rust/pull/116837 +// edition: 2021 + +#![feature(rustc_private)] +#![feature(assert_matches)] +#![feature(control_flow_enum)] + +extern crate rustc_middle; +#[macro_use] +extern crate rustc_smir; +extern crate rustc_driver; +extern crate rustc_interface; +extern crate stable_mir; + +use rustc_middle::ty::TyCtxt; +use rustc_smir::rustc_internal; +use stable_mir::ty::{RigidTy, TyKind, Ty, }; +use stable_mir::mir::{Body, MirVisitor, FieldIdx, Place, ProjectionElem, visit::{Location, + PlaceContext}}; +use std::io::Write; +use std::ops::ControlFlow; + +const CRATE_NAME: &str = "input"; + +/// This function uses the Stable MIR APIs to get information about the test crate. +fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { + let main_fn = stable_mir::entry_fn(); + let body = main_fn.unwrap().body(); + let mut visitor = PlaceVisitor{ body: &body, tested: false}; + visitor.visit_body(&body); + assert!(visitor.tested); + ControlFlow::Continue(()) +} + +struct PlaceVisitor<'a> { + body: &'a Body, + /// Used to ensure that the test was reachable. Otherwise this test would vacuously succeed. + tested: bool, +} + +/// Check that `wrapper.inner` place projection can be correctly interpreted. +/// Ensure that instantiation is correct. +fn check_tys(local_ty: Ty, idx: FieldIdx, expected_ty: Ty) { + let TyKind::RigidTy(RigidTy::Adt(def, args)) = local_ty.kind() else { unreachable!() }; + assert_eq!(def.ty_with_args(&args), local_ty); + + let field_def = &def.variants_iter().next().unwrap().fields()[idx]; + let field_ty = field_def.ty_with_args(&args); + assert_eq!(field_ty, expected_ty); + + // Check that the generic version is different than the instantiated one. + let field_ty_gen = field_def.ty(); + assert_ne!(field_ty_gen, field_ty); +} + +impl<'a> MirVisitor for PlaceVisitor<'a> { + fn visit_place(&mut self, place: &Place, _ptx: PlaceContext, _loc: Location) { + let start_ty = self.body.locals()[place.local].ty; + match place.projection.as_slice() { + [ProjectionElem::Field(idx, ty)] => { + check_tys(start_ty, *idx, *ty); + self.tested = true; + } + _ => {} + } + } +} + +/// This test will generate and analyze a dummy crate using the stable mir. +/// For that, it will first write the dummy crate into a file. +/// Then it will create a `StableMir` using custom arguments and then +/// it will run the compiler. +fn main() { + let path = "ty_fold_input.rs"; + generate_input(&path).unwrap(); + let args = vec![ + "rustc".to_string(), + "-Cpanic=abort".to_string(), + "--crate-name".to_string(), + CRATE_NAME.to_string(), + path.to_string(), + ]; + run!(args, tcx, test_stable_mir(tcx)).unwrap(); +} + +fn generate_input(path: &str) -> std::io::Result<()> { + let mut file = std::fs::File::create(path)?; + write!( + file, + r#" + struct Wrapper {{ + pub inner: T + }} + + impl Wrapper {{ + pub fn new() -> Wrapper {{ + Wrapper {{ inner: T::default() }} + }} + }} + + fn main() {{ + let wrapper = Wrapper::::new(); + let _inner = wrapper.inner; + }} + "# + )?; + Ok(()) +}