Add a method to check if type is a CStr

This commit is contained in:
Celina G. Val 2023-12-15 13:18:41 -08:00
parent 3f39cae119
commit 86451badf1
4 changed files with 37 additions and 0 deletions

View File

@ -220,6 +220,12 @@ fn adt_is_simd(&self, def: AdtDef) -> bool {
def.internal(&mut *tables).repr().simd()
}
fn adt_is_cstr(&self, def: AdtDef) -> bool {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);
tables.tcx.lang_items().c_str() == Some(def_id)
}
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);

View File

@ -72,6 +72,9 @@ pub trait Context {
/// Returns whether this ADT is simd.
fn adt_is_simd(&self, def: AdtDef) -> bool;
/// Returns whether this definition is a C string.
fn adt_is_cstr(&self, def: AdtDef) -> bool;
/// Retrieve the function signature for the given generic arguments.
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig;

View File

@ -316,6 +316,12 @@ pub fn is_str(&self) -> bool {
*self == TyKind::RigidTy(RigidTy::Str)
}
#[inline]
pub fn is_cstr(&self) -> bool {
let TyKind::RigidTy(RigidTy::Adt(def, _)) = self else { return false };
with(|cx| cx.adt_is_cstr(*def))
}
#[inline]
pub fn is_slice(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Slice(_)))

View File

@ -33,6 +33,7 @@
use std::assert_matches::assert_matches;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::ffi::CStr;
use std::io::Write;
use std::ops::ControlFlow;
@ -45,6 +46,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
check_foo(*get_item(&items, (ItemKind::Static, "FOO")).unwrap());
check_bar(*get_item(&items, (ItemKind::Static, "BAR")).unwrap());
check_len(*get_item(&items, (ItemKind::Static, "LEN")).unwrap());
check_cstr(*get_item(&items, (ItemKind::Static, "C_STR")).unwrap());
check_other_consts(*get_item(&items, (ItemKind::Fn, "other_consts")).unwrap());
check_type_id(*get_item(&items, (ItemKind::Fn, "check_type_id")).unwrap());
ControlFlow::Continue(())
@ -86,6 +88,24 @@ fn check_bar(item: CrateItem) {
assert_eq!(std::str::from_utf8(&allocation.raw_bytes().unwrap()), Ok("Bar"));
}
/// Check the allocation data for static `C_STR`.
///
/// ```no_run
/// static C_STR: &core::ffi::cstr = c"cstr";
/// ```
fn check_cstr(item: CrateItem) {
let def = StaticDef::try_from(item).unwrap();
let alloc = def.eval_initializer().unwrap();
assert_eq!(alloc.provenance.ptrs.len(), 1);
let deref = item.ty().kind().builtin_deref(true).unwrap();
assert!(deref.ty.kind().is_cstr(), "Expected CStr, but got: {:?}", item.ty());
let alloc_id_0 = alloc.provenance.ptrs[0].1.0;
let GlobalAlloc::Memory(allocation) = GlobalAlloc::from(alloc_id_0) else { unreachable!() };
assert_eq!(allocation.bytes.len(), 5);
assert_eq!(CStr::from_bytes_until_nul(&allocation.raw_bytes().unwrap()), Ok(c"cstr"));
}
/// Check the allocation data for constants used in `other_consts` function.
fn check_other_consts(item: CrateItem) {
// Instance body will force constant evaluation.
@ -206,6 +226,7 @@ fn main() {
generate_input(&path).unwrap();
let args = vec![
"rustc".to_string(),
"--edition=2021".to_string(),
"--crate-name".to_string(),
CRATE_NAME.to_string(),
path.to_string(),
@ -224,6 +245,7 @@ fn generate_input(path: &str) -> std::io::Result<()> {
static LEN: usize = 2;
static FOO: [&str; 2] = ["hi", "there"];
static BAR: &str = "Bar";
static C_STR: &std::ffi::CStr = c"cstr";
const NULL: *const u8 = std::ptr::null();
const TUPLE: (u32, u32) = (10, u32::MAX);