[stdio][windows] Use MultiByteToWideChar (MBTWC) and WideCharToMultiByte (WCTMB)

This commit is contained in:
Nicole Mazzuca 2023-01-19 23:21:21 -08:00
parent 07c993eba8
commit 7f25580512
3 changed files with 78 additions and 29 deletions

View File

@ -232,6 +232,7 @@
all(target_vendor = "fortanix", target_env = "sgx"), all(target_vendor = "fortanix", target_env = "sgx"),
feature(slice_index_methods, coerce_unsized, sgx_platform) feature(slice_index_methods, coerce_unsized, sgx_platform)
)] )]
#![cfg_attr(windows, feature(round_char_boundary))]
// //
// Language features: // Language features:
#![feature(alloc_error_handler)] #![feature(alloc_error_handler)]

View File

@ -6,13 +6,15 @@
use crate::ffi::CStr; use crate::ffi::CStr;
use crate::mem; use crate::mem;
use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort}; use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull}; use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
use crate::ptr; use crate::ptr;
use core::ffi::NonZero_c_ulong; use core::ffi::NonZero_c_ulong;
use libc::{c_void, size_t, wchar_t}; use libc::{c_void, size_t, wchar_t};
pub use crate::os::raw::c_int;
#[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this #[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this
mod errors; mod errors;
pub use errors::*; pub use errors::*;
@ -47,16 +49,19 @@ pub type ACCESS_MASK = DWORD;
pub type LPBOOL = *mut BOOL; pub type LPBOOL = *mut BOOL;
pub type LPBYTE = *mut BYTE; pub type LPBYTE = *mut BYTE;
pub type LPCCH = *const CHAR;
pub type LPCSTR = *const CHAR; pub type LPCSTR = *const CHAR;
pub type LPCWCH = *const WCHAR;
pub type LPCWSTR = *const WCHAR; pub type LPCWSTR = *const WCHAR;
pub type LPCVOID = *const c_void;
pub type LPDWORD = *mut DWORD; pub type LPDWORD = *mut DWORD;
pub type LPHANDLE = *mut HANDLE; pub type LPHANDLE = *mut HANDLE;
pub type LPOVERLAPPED = *mut OVERLAPPED; pub type LPOVERLAPPED = *mut OVERLAPPED;
pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION; pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION;
pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES; pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES;
pub type LPSTARTUPINFO = *mut STARTUPINFO; pub type LPSTARTUPINFO = *mut STARTUPINFO;
pub type LPSTR = *mut CHAR;
pub type LPVOID = *mut c_void; pub type LPVOID = *mut c_void;
pub type LPCVOID = *const c_void;
pub type LPWCH = *mut WCHAR; pub type LPWCH = *mut WCHAR;
pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW; pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW;
pub type LPWSADATA = *mut WSADATA; pub type LPWSADATA = *mut WSADATA;
@ -132,6 +137,10 @@ pub const MAX_PATH: usize = 260;
pub const FILE_TYPE_PIPE: u32 = 3; pub const FILE_TYPE_PIPE: u32 = 3;
/// Code page identifier selecting UTF-8; passed as the `CodePage` argument of
/// `MultiByteToWideChar` / `WideCharToMultiByte`.
pub const CP_UTF8: DWORD = 65001;
/// `dwFlags` value for `MultiByteToWideChar`. Per the Win32 documentation it makes the
/// conversion fail on invalid input instead of silently substituting characters.
pub const MB_ERR_INVALID_CHARS: DWORD = 0x08;
/// `dwFlags` value for `WideCharToMultiByte`. Per the Win32 documentation it makes the
/// conversion fail on invalid UTF-16 (e.g. unpaired surrogates) instead of substituting.
pub const WC_ERR_INVALID_CHARS: DWORD = 0x80;
#[repr(C)] #[repr(C)]
#[derive(Copy)] #[derive(Copy)]
pub struct WIN32_FIND_DATAW { pub struct WIN32_FIND_DATAW {
@ -1155,6 +1164,25 @@ extern "system" {
lpFilePart: *mut LPWSTR, lpFilePart: *mut LPWSTR,
) -> DWORD; ) -> DWORD;
pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD; pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD;
/// Win32 `MultiByteToWideChar`: converts the byte string `lpMultiByteStr`
/// (`cbMultiByte` bytes, encoded in `CodePage`) into UTF-16 code units stored in
/// `lpWideCharStr` (capacity `cchWideChar` code units).
///
/// Callers treat a return value of 0 as failure and any other value as the number of
/// UTF-16 code units written. Lengths are `c_int`, so callers must ensure both buffer
/// lengths fit in a `c_int` before casting.
pub fn MultiByteToWideChar(
CodePage: UINT,
dwFlags: DWORD,
lpMultiByteStr: LPCCH,
cbMultiByte: c_int,
lpWideCharStr: LPWSTR,
cchWideChar: c_int,
) -> c_int;
/// Win32 `WideCharToMultiByte`: converts the UTF-16 string `lpWideCharStr`
/// (`cchWideChar` code units) into bytes encoded in `CodePage`, stored in
/// `lpMultiByteStr` (capacity `cbMultiByte` bytes).
///
/// Callers treat a return value of 0 as failure and any other value as the number of
/// bytes written. Per the Win32 documentation the call also fails when `cchWideChar`
/// is 0, and `lpDefaultChar` / `lpUsedDefaultChar` must be null when `CodePage` is
/// `CP_UTF8`.
pub fn WideCharToMultiByte(
CodePage: UINT,
dwFlags: DWORD,
lpWideCharStr: LPCWCH,
cchWideChar: c_int,
lpMultiByteStr: LPSTR,
cbMultiByte: c_int,
lpDefaultChar: LPCCH,
lpUsedDefaultChar: LPBOOL,
) -> c_int;
} }
#[link(name = "ws2_32")] #[link(name = "ws2_32")]

View File

@ -169,14 +169,27 @@ fn write(
} }
fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> { fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
debug_assert!(!utf8.is_empty());
let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2]; let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let mut len_utf16 = 0; let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())];
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
*dest = MaybeUninit::new(chr); let utf16: &[u16] = unsafe {
len_utf16 += 1; // Note that this theoretically checks validity twice in the (most common) case
} // where the underlying byte sequence is valid utf-8 (given the check in `write()`).
// Safety: We've initialized `len_utf16` values. let result = c::MultiByteToWideChar(
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) }; c::CP_UTF8, // CodePage
c::MB_ERR_INVALID_CHARS, // dwFlags
utf8.as_ptr() as c::LPCCH, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
utf16.as_mut_ptr() as c::LPWSTR, // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
);
assert!(result != 0, "Unexpected error in MultiByteToWideChar");
// Safety: MultiByteToWideChar initializes `result` values.
MaybeUninit::slice_assume_init_ref(&utf16[..result as usize])
};
let mut written = write_u16s(handle, &utf16)?; let mut written = write_u16s(handle, &utf16)?;
@ -189,8 +202,8 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
// a missing surrogate can be produced (and also because of the UTF-8 validation above), // a missing surrogate can be produced (and also because of the UTF-8 validation above),
// write the missing surrogate out now. // write the missing surrogate out now.
// Buffering it would mean we have to lie about the number of bytes written. // Buffering it would mean we have to lie about the number of bytes written.
let first_char_remaining = utf16[written]; let first_code_unit_remaining = utf16[written];
if first_char_remaining >= 0xDCEE && first_char_remaining <= 0xDFFF { if first_code_unit_remaining >= 0xDCEE && first_code_unit_remaining <= 0xDFFF {
// low surrogate // low surrogate
// We just hope this works, and give up otherwise // We just hope this works, and give up otherwise
let _ = write_u16s(handle, &utf16[written..written + 1]); let _ = write_u16s(handle, &utf16[written..written + 1]);
@ -212,6 +225,7 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
} }
fn write_u16s(handle: c::HANDLE, data: &[u16]) -> io::Result<usize> { fn write_u16s(handle: c::HANDLE, data: &[u16]) -> io::Result<usize> {
debug_assert!(data.len() < u32::MAX as usize);
let mut written = 0; let mut written = 0;
cvt(unsafe { cvt(unsafe {
c::WriteConsoleW( c::WriteConsoleW(
@ -365,26 +379,32 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usiz
Ok(amount as usize) Ok(amount as usize)
} }
#[allow(unused)]
fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> { fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> {
let mut written = 0; debug_assert!(utf16.len() <= c::c_int::MAX as usize);
for chr in char::decode_utf16(utf16.iter().cloned()) { debug_assert!(utf8.len() <= c::c_int::MAX as usize);
match chr {
Ok(chr) => { let result = unsafe {
chr.encode_utf8(&mut utf8[written..]); c::WideCharToMultiByte(
written += chr.len_utf8(); c::CP_UTF8, // CodePage
} c::WC_ERR_INVALID_CHARS, // dwFlags
Err(_) => { utf16.as_ptr(), // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
utf8.as_mut_ptr() as c::LPSTR, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
ptr::null(), // lpDefaultChar
ptr::null_mut(), // lpUsedDefaultChar
)
};
if result == 0 {
// We can't really do any better than forget all data and return an error. // We can't really do any better than forget all data and return an error.
return Err(io::const_io_error!( Err(io::const_io_error!(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"Windows stdin in console mode does not support non-UTF-16 input; \ "Windows stdin in console mode does not support non-UTF-16 input; \
encountered unpaired surrogate", encountered unpaired surrogate",
)); ))
} else {
Ok(result as usize)
} }
}
}
Ok(written)
} }
impl IncompleteUtf8 { impl IncompleteUtf8 {