diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs
index 241a0c22310..70313cc021e 100644
--- a/compiler/rustc_smir/src/rustc_smir/context.rs
+++ b/compiler/rustc_smir/src/rustc_smir/context.rs
@@ -5,7 +5,9 @@
use rustc_middle::ty;
use rustc_middle::ty::print::{with_forced_trimmed_paths, with_no_trimmed_paths};
-use rustc_middle::ty::{GenericPredicates, Instance, ParamEnv, ScalarInt, ValTree};
+use rustc_middle::ty::{
+ GenericPredicates, Instance, ParamEnv, ScalarInt, TypeVisitableExt, ValTree,
+};
use rustc_span::def_id::LOCAL_CRATE;
use stable_mir::compiler_interface::Context;
use stable_mir::mir::alloc::GlobalAlloc;
@@ -324,7 +326,8 @@ fn instance_body(&self, def: InstanceDef) -> Option
{
fn instance_ty(&self, def: InstanceDef) -> stable_mir::ty::Ty {
let mut tables = self.0.borrow_mut();
let instance = tables.instances[def];
- instance.ty(tables.tcx, ParamEnv::empty()).stable(&mut *tables)
+ assert!(!instance.has_non_region_param(), "{instance:?} needs further substitution");
+ instance.ty(tables.tcx, ParamEnv::reveal_all()).stable(&mut *tables)
}
fn instance_def_id(&self, def: InstanceDef) -> stable_mir::DefId {
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
index 8c1767501d9..41ab4007a67 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs
@@ -36,6 +36,7 @@ fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
.collect(),
self.arg_count,
self.var_debug_info.iter().map(|info| info.stable(tables)).collect(),
+ self.spread_arg.stable(tables),
)
}
}
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs
index 7d8339ab503..7021bdda735 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs
@@ -57,7 +57,9 @@ fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
stable_mir::mir::CoroutineKind::Gen(source.stable(tables))
}
CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine,
- CoroutineKind::AsyncGen(_) => todo!(),
+ CoroutineKind::AsyncGen(source) => {
+ stable_mir::mir::CoroutineKind::AsyncGen(source.stable(tables))
+ }
}
}
}
diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs
index 3dfe7096399..5023af9ab79 100644
--- a/compiler/stable_mir/src/mir/body.rs
+++ b/compiler/stable_mir/src/mir/body.rs
@@ -22,6 +22,11 @@ pub struct Body {
/// Debug information pertaining to user variables, including captures.
pub var_debug_info: Vec,
+
+ /// Mark an argument (which must be a tuple) as getting passed as its individual components.
+ ///
+ /// This is used for the "rust-call" ABI such as closures.
+ pub(super) spread_arg: Option,
}
pub type BasicBlockIdx = usize;
@@ -36,6 +41,7 @@ pub fn new(
locals: LocalDecls,
arg_count: usize,
var_debug_info: Vec,
+ spread_arg: Option,
) -> Self {
// If locals doesn't contain enough entries, it can lead to panics in
// `ret_local`, `arg_locals`, and `inner_locals`.
@@ -43,7 +49,7 @@ pub fn new(
locals.len() > arg_count,
"A Body must contain at least a local for the return value and each of the function's arguments"
);
- Self { blocks, locals, arg_count, var_debug_info }
+ Self { blocks, locals, arg_count, var_debug_info, spread_arg }
}
/// Return local that holds this function's return value.
@@ -75,6 +81,11 @@ pub fn local_decl(&self, local: Local) -> Option<&LocalDecl> {
self.locals.get(local)
}
+ /// Get an iterator for all local declarations.
+ pub fn local_decls(&self) -> impl Iterator- {
+ self.locals.iter().enumerate()
+ }
+
pub fn dump(&self, w: &mut W) -> io::Result<()> {
writeln!(w, "{}", function_body(self))?;
self.blocks
@@ -98,6 +109,10 @@ pub fn dump(&self, w: &mut W) -> io::Result<()> {
.collect::, _>>()?;
Ok(())
}
+
+ pub fn spread_arg(&self) -> Option {
+ self.spread_arg
+ }
}
type LocalDecls = Vec;
@@ -248,6 +263,57 @@ pub enum AssertMessage {
MisalignedPointerDereference { required: Operand, found: Operand },
}
+impl AssertMessage {
+ pub fn description(&self) -> Result<&'static str, Error> {
+ match self {
+ AssertMessage::Overflow(BinOp::Add, _, _) => Ok("attempt to add with overflow"),
+ AssertMessage::Overflow(BinOp::Sub, _, _) => Ok("attempt to subtract with overflow"),
+ AssertMessage::Overflow(BinOp::Mul, _, _) => Ok("attempt to multiply with overflow"),
+ AssertMessage::Overflow(BinOp::Div, _, _) => Ok("attempt to divide with overflow"),
+ AssertMessage::Overflow(BinOp::Rem, _, _) => {
+ Ok("attempt to calculate the remainder with overflow")
+ }
+ AssertMessage::OverflowNeg(_) => Ok("attempt to negate with overflow"),
+ AssertMessage::Overflow(BinOp::Shr, _, _) => Ok("attempt to shift right with overflow"),
+ AssertMessage::Overflow(BinOp::Shl, _, _) => Ok("attempt to shift left with overflow"),
+ AssertMessage::Overflow(op, _, _) => Err(error!("`{:?}` cannot overflow", op)),
+ AssertMessage::DivisionByZero(_) => Ok("attempt to divide by zero"),
+ AssertMessage::RemainderByZero(_) => {
+ Ok("attempt to calculate the remainder with a divisor of zero")
+ }
+ AssertMessage::ResumedAfterReturn(CoroutineKind::Coroutine) => {
+ Ok("coroutine resumed after completion")
+ }
+ AssertMessage::ResumedAfterReturn(CoroutineKind::Async(_)) => {
+ Ok("`async fn` resumed after completion")
+ }
+ AssertMessage::ResumedAfterReturn(CoroutineKind::Gen(_)) => {
+ Ok("`async gen fn` resumed after completion")
+ }
+ AssertMessage::ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => {
+ Ok("`gen fn` should just keep returning `AssertMessage::None` after completion")
+ }
+ AssertMessage::ResumedAfterPanic(CoroutineKind::Coroutine) => {
+ Ok("coroutine resumed after panicking")
+ }
+ AssertMessage::ResumedAfterPanic(CoroutineKind::Async(_)) => {
+ Ok("`async fn` resumed after panicking")
+ }
+ AssertMessage::ResumedAfterPanic(CoroutineKind::Gen(_)) => {
+ Ok("`async gen fn` resumed after panicking")
+ }
+ AssertMessage::ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => {
+ Ok("`gen fn` should just keep returning `AssertMessage::None` after panicking")
+ }
+
+ AssertMessage::BoundsCheck { .. } => Ok("index out of bounds"),
+ AssertMessage::MisalignedPointerDereference { .. } => {
+ Ok("misaligned pointer dereference")
+ }
+ }
+ }
+}
+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum BinOp {
Add,
@@ -325,6 +391,7 @@ pub enum CoroutineKind {
Async(CoroutineSource),
Coroutine,
Gen(CoroutineSource),
+ AsyncGen(CoroutineSource),
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
diff --git a/compiler/stable_mir/src/mir/mono.rs b/compiler/stable_mir/src/mir/mono.rs
index bc5d4a3b8f4..c126de23c4b 100644
--- a/compiler/stable_mir/src/mir/mono.rs
+++ b/compiler/stable_mir/src/mir/mono.rs
@@ -1,6 +1,6 @@
use crate::crate_def::CrateDef;
use crate::mir::Body;
-use crate::ty::{Allocation, ClosureDef, ClosureKind, FnDef, FnSig, GenericArgs, IndexedVal, Ty};
+use crate::ty::{Allocation, ClosureDef, ClosureKind, FnDef, GenericArgs, IndexedVal, Ty};
use crate::{with, CrateItem, DefId, Error, ItemKind, Opaque, Symbol};
use std::fmt::{Debug, Formatter};
@@ -115,11 +115,6 @@ pub fn resolve_closure(
})
}
- /// Get this function signature with all types already instantiated.
- pub fn fn_sig(&self) -> FnSig {
- self.ty().kind().fn_sig().unwrap().skip_binder()
- }
-
/// Check whether this instance is an empty shim.
///
/// Allow users to check if this shim can be ignored when called directly.
diff --git a/compiler/stable_mir/src/mir/pretty.rs b/compiler/stable_mir/src/mir/pretty.rs
index 576087498ab..8b7b488d312 100644
--- a/compiler/stable_mir/src/mir/pretty.rs
+++ b/compiler/stable_mir/src/mir/pretty.rs
@@ -260,6 +260,7 @@ pub fn pretty_assert_message(msg: &AssertMessage) -> String {
);
pretty
}
+ AssertMessage::Overflow(op, _, _) => unreachable!("`{:?}` cannot overflow", op),
AssertMessage::OverflowNeg(op) => {
let pretty_op = pretty_operand(op);
pretty.push_str(
@@ -279,17 +280,15 @@ pub fn pretty_assert_message(msg: &AssertMessage) -> String {
);
pretty
}
- AssertMessage::ResumedAfterReturn(_) => {
- format!("attempt to resume a generator after completion")
- }
- AssertMessage::ResumedAfterPanic(_) => format!("attempt to resume a panicked generator"),
AssertMessage::MisalignedPointerDereference { required, found } => {
let pretty_required = pretty_operand(required);
let pretty_found = pretty_operand(found);
pretty.push_str(format!("\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\",{pretty_required}, {pretty_found}").as_str());
pretty
}
- _ => todo!(),
+ AssertMessage::ResumedAfterReturn(_) | AssertMessage::ResumedAfterPanic(_) => {
+ msg.description().unwrap().to_string()
+ }
}
}
diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs
index d46caad9a01..98336a72900 100644
--- a/compiler/stable_mir/src/mir/visit.rs
+++ b/compiler/stable_mir/src/mir/visit.rs
@@ -133,7 +133,7 @@ fn visit_var_debug_info(&mut self, var_debug_info: &VarDebugInfo) {
}
fn super_body(&mut self, body: &Body) {
- let Body { blocks, locals: _, arg_count, var_debug_info } = body;
+ let Body { blocks, locals: _, arg_count, var_debug_info, spread_arg: _ } = body;
for bb in blocks {
self.visit_basic_block(bb);
diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs
index f473fd8dbb7..3d5e264104b 100644
--- a/compiler/stable_mir/src/ty.rs
+++ b/compiler/stable_mir/src/ty.rs
@@ -22,9 +22,7 @@ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
/// Constructors for `Ty`.
impl Ty {
/// Create a new type from a given kind.
- ///
- /// Note that not all types may be supported at this point.
- fn from_rigid_kind(kind: RigidTy) -> Ty {
+ pub fn from_rigid_kind(kind: RigidTy) -> Ty {
with(|cx| cx.new_rigid_ty(kind))
}
@@ -77,6 +75,16 @@ pub fn usize_ty() -> Ty {
pub fn bool_ty() -> Ty {
Ty::from_rigid_kind(RigidTy::Bool)
}
+
+ /// Create a type representing a signed integer.
+ pub fn signed_ty(inner: IntTy) -> Ty {
+ Ty::from_rigid_kind(RigidTy::Int(inner))
+ }
+
+ /// Create a type representing an unsigned integer.
+ pub fn unsigned_ty(inner: UintTy) -> Ty {
+ Ty::from_rigid_kind(RigidTy::Uint(inner))
+ }
}
impl Ty {
diff --git a/tests/ui-fulldeps/stable-mir/check_defs.rs b/tests/ui-fulldeps/stable-mir/check_defs.rs
index d311be5982d..ad667511332 100644
--- a/tests/ui-fulldeps/stable-mir/check_defs.rs
+++ b/tests/ui-fulldeps/stable-mir/check_defs.rs
@@ -69,7 +69,7 @@ fn extract_elem_ty(ty: Ty) -> Ty {
/// Check signature and type of `Vec::::new` and its generic version.
fn test_vec_new(instance: mir::mono::Instance) {
- let sig = instance.fn_sig();
+ let sig = instance.ty().kind().fn_sig().unwrap().skip_binder();
assert_matches!(sig.inputs(), &[]);
let elem_ty = extract_elem_ty(sig.output());
assert_matches!(elem_ty.kind(), TyKind::RigidTy(RigidTy::Uint(UintTy::U8)));