Auto merge of #119000 - celinval:smir-cstr, r=ouz-a

Add a method to StableMIR to check if a type is a CStr

Also add a check that StableMIR works properly with C string literal.
This commit is contained in:
bors 2023-12-17 08:18:17 +00:00
commit 9f13b9d9ca
4 changed files with 37 additions and 0 deletions

View File

@ -219,6 +219,12 @@ fn adt_is_simd(&self, def: AdtDef) -> bool {
def.internal(&mut *tables).repr().simd() def.internal(&mut *tables).repr().simd()
} }
fn adt_is_cstr(&self, def: AdtDef) -> bool {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);
tables.tcx.lang_items().c_str() == Some(def_id)
}
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig { fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig {
let mut tables = self.0.borrow_mut(); let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables); let def_id = def.0.internal(&mut *tables);

View File

@ -72,6 +72,9 @@ pub trait Context {
/// Returns whether this ADT is simd. /// Returns whether this ADT is simd.
fn adt_is_simd(&self, def: AdtDef) -> bool; fn adt_is_simd(&self, def: AdtDef) -> bool;
/// Returns whether this definition is a C string.
fn adt_is_cstr(&self, def: AdtDef) -> bool;
/// Retrieve the function signature for the given generic arguments. /// Retrieve the function signature for the given generic arguments.
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig; fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig;

View File

@ -316,6 +316,12 @@ pub fn is_str(&self) -> bool {
*self == TyKind::RigidTy(RigidTy::Str) *self == TyKind::RigidTy(RigidTy::Str)
} }
#[inline]
pub fn is_cstr(&self) -> bool {
let TyKind::RigidTy(RigidTy::Adt(def, _)) = self else { return false };
with(|cx| cx.adt_is_cstr(*def))
}
#[inline] #[inline]
pub fn is_slice(&self) -> bool { pub fn is_slice(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Slice(_))) matches!(self, TyKind::RigidTy(RigidTy::Slice(_)))

View File

@ -33,6 +33,7 @@
use std::assert_matches::assert_matches; use std::assert_matches::assert_matches;
use std::cmp::{max, min}; use std::cmp::{max, min};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::CStr;
use std::io::Write; use std::io::Write;
use std::ops::ControlFlow; use std::ops::ControlFlow;
@ -45,6 +46,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
check_foo(*get_item(&items, (ItemKind::Static, "FOO")).unwrap()); check_foo(*get_item(&items, (ItemKind::Static, "FOO")).unwrap());
check_bar(*get_item(&items, (ItemKind::Static, "BAR")).unwrap()); check_bar(*get_item(&items, (ItemKind::Static, "BAR")).unwrap());
check_len(*get_item(&items, (ItemKind::Static, "LEN")).unwrap()); check_len(*get_item(&items, (ItemKind::Static, "LEN")).unwrap());
check_cstr(*get_item(&items, (ItemKind::Static, "C_STR")).unwrap());
check_other_consts(*get_item(&items, (ItemKind::Fn, "other_consts")).unwrap()); check_other_consts(*get_item(&items, (ItemKind::Fn, "other_consts")).unwrap());
check_type_id(*get_item(&items, (ItemKind::Fn, "check_type_id")).unwrap()); check_type_id(*get_item(&items, (ItemKind::Fn, "check_type_id")).unwrap());
ControlFlow::Continue(()) ControlFlow::Continue(())
@ -86,6 +88,24 @@ fn check_bar(item: CrateItem) {
assert_eq!(std::str::from_utf8(&allocation.raw_bytes().unwrap()), Ok("Bar")); assert_eq!(std::str::from_utf8(&allocation.raw_bytes().unwrap()), Ok("Bar"));
} }
/// Check the allocation data for static `C_STR`.
///
/// ```no_run
/// static C_STR: &core::ffi::cstr = c"cstr";
/// ```
fn check_cstr(item: CrateItem) {
let def = StaticDef::try_from(item).unwrap();
let alloc = def.eval_initializer().unwrap();
assert_eq!(alloc.provenance.ptrs.len(), 1);
let deref = item.ty().kind().builtin_deref(true).unwrap();
assert!(deref.ty.kind().is_cstr(), "Expected CStr, but got: {:?}", item.ty());
let alloc_id_0 = alloc.provenance.ptrs[0].1.0;
let GlobalAlloc::Memory(allocation) = GlobalAlloc::from(alloc_id_0) else { unreachable!() };
assert_eq!(allocation.bytes.len(), 5);
assert_eq!(CStr::from_bytes_until_nul(&allocation.raw_bytes().unwrap()), Ok(c"cstr"));
}
/// Check the allocation data for constants used in `other_consts` function. /// Check the allocation data for constants used in `other_consts` function.
fn check_other_consts(item: CrateItem) { fn check_other_consts(item: CrateItem) {
// Instance body will force constant evaluation. // Instance body will force constant evaluation.
@ -206,6 +226,7 @@ fn main() {
generate_input(&path).unwrap(); generate_input(&path).unwrap();
let args = vec![ let args = vec![
"rustc".to_string(), "rustc".to_string(),
"--edition=2021".to_string(),
"--crate-name".to_string(), "--crate-name".to_string(),
CRATE_NAME.to_string(), CRATE_NAME.to_string(),
path.to_string(), path.to_string(),
@ -224,6 +245,7 @@ fn generate_input(path: &str) -> std::io::Result<()> {
static LEN: usize = 2; static LEN: usize = 2;
static FOO: [&str; 2] = ["hi", "there"]; static FOO: [&str; 2] = ["hi", "there"];
static BAR: &str = "Bar"; static BAR: &str = "Bar";
static C_STR: &std::ffi::CStr = c"cstr";
const NULL: *const u8 = std::ptr::null(); const NULL: *const u8 = std::ptr::null();
const TUPLE: (u32, u32) = (10, u32::MAX); const TUPLE: (u32, u32) = (10, u32::MAX);