Avoid overflow in IoSlice::advance_slices

This commit is contained in:
Eduardo Sánchez Muñoz 2023-09-22 20:44:56 +02:00
parent 959b2c703d
commit 93863383c8

View File

@ -1236,22 +1236,22 @@ impl<'a> IoSliceMut<'a> {
pub fn advance_slices(bufs: &mut &mut [IoSliceMut<'a>], n: usize) { pub fn advance_slices(bufs: &mut &mut [IoSliceMut<'a>], n: usize) {
// Number of buffers to remove. // Number of buffers to remove.
let mut remove = 0; let mut remove = 0;
// Total length of all the to be removed buffers. // Remaining length before reaching n.
let mut accumulated_len = 0; let mut left = n;
for buf in bufs.iter() { for buf in bufs.iter() {
if accumulated_len + buf.len() > n { if let Some(remainder) = left.checked_sub(buf.len()) {
break; left = remainder;
} else {
accumulated_len += buf.len();
remove += 1; remove += 1;
} else {
break;
} }
} }
*bufs = &mut take(bufs)[remove..]; *bufs = &mut take(bufs)[remove..];
if bufs.is_empty() { if bufs.is_empty() {
assert!(n == accumulated_len, "advancing io slices beyond their length"); assert!(left == 0, "advancing io slices beyond their length");
} else { } else {
bufs[0].advance(n - accumulated_len) bufs[0].advance(left);
} }
} }
} }
@ -1379,22 +1379,25 @@ impl<'a> IoSlice<'a> {
pub fn advance_slices(bufs: &mut &mut [IoSlice<'a>], n: usize) { pub fn advance_slices(bufs: &mut &mut [IoSlice<'a>], n: usize) {
// Number of buffers to remove. // Number of buffers to remove.
let mut remove = 0; let mut remove = 0;
// Total length of all the to be removed buffers. // Remaining length before reaching n. This prevents overflow
let mut accumulated_len = 0; // that could happen if the length of slices in `bufs` were instead
// accumulated. Those slice may be aliased and, if they are large
// enough, their added length may overflow a `usize`.
let mut left = n;
for buf in bufs.iter() { for buf in bufs.iter() {
if accumulated_len + buf.len() > n { if let Some(remainder) = left.checked_sub(buf.len()) {
break; left = remainder;
} else {
accumulated_len += buf.len();
remove += 1; remove += 1;
} else {
break;
} }
} }
*bufs = &mut take(bufs)[remove..]; *bufs = &mut take(bufs)[remove..];
if bufs.is_empty() { if bufs.is_empty() {
assert!(n == accumulated_len, "advancing io slices beyond their length"); assert!(left == 0, "advancing io slices beyond their length");
} else { } else {
bufs[0].advance(n - accumulated_len) bufs[0].advance(left);
} }
} }
} }