Skip to content

Commit

Permalink
fix: Ensure Binary -> Binview cast doesn't overflow the buffer size (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 31, 2024
1 parent 1485fec commit 2b28777
Showing 1 changed file with 99 additions and 7 deletions.
106 changes: 99 additions & 7 deletions crates/polars-arrow/src/compute/cast/utf8_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use polars_utils::slice::GetSaferUnchecked;
use polars_utils::vec::PushUnchecked;

use crate::array::*;
use crate::buffer::Buffer;
use crate::datatypes::ArrowDataType;
use crate::offset::Offset;
use crate::types::NativeType;
Expand Down Expand Up @@ -69,14 +70,51 @@ pub fn utf8_to_binary<O: Offset>(
}
}

// Different types to test the overflow path.
#[cfg(not(test))]
type OffsetType = u32;

// To trigger overflow
#[cfg(test)]
type OffsetType = i8;

// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple
// chunks so that we don't overflow the offset u32.
fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {
// * 2, as it must be able to hold u32::MAX offset + u32::MAX len.
buf.clone()
.sliced(0, std::cmp::min(buf.len(), OffsetType::MAX as usize * 2))
}

pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
let buffer_idx = 0_u32;
let base_ptr = arr.values().as_ptr() as usize;
// Ensure we didn't accidentally set wrong type
#[cfg(not(debug_assertions))]
{
assert_eq!(
std::mem::size_of::<u32>(),
std::mem::size_of::<OffsetType>()
);
}

let mut views = Vec::with_capacity(arr.len());
let mut uses_buffer = false;

let mut base_buffer = arr.values().clone();
// Offset into the buffer
let mut base_ptr = base_buffer.as_ptr() as usize;

// Offset into the binview buffers
let mut buffer_idx = 0_u32;

// Binview buffers
// Note that the buffer may look far further than u32::MAX, but as we don't clone data
let mut buffers = vec![truncate_buffer(&base_buffer)];

for bytes in arr.values_iter() {
let len: u32 = bytes.len().try_into().unwrap();
let len: u32 = bytes
.len()
.try_into()
.expect("max string/binary length exceeded");

let mut payload = [0; 16];
payload[0..4].copy_from_slice(&len.to_le_bytes());
Expand All @@ -85,18 +123,42 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
payload[4..4 + bytes.len()].copy_from_slice(bytes);
} else {
uses_buffer = true;

// Copy the parts we know are correct.
unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked_release(0..4)) };
let offset = (bytes.as_ptr() as usize - base_ptr) as u32;
payload[0..4].copy_from_slice(&len.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
payload[12..16].copy_from_slice(&offset.to_le_bytes());

let current_bytes_ptr = bytes.as_ptr() as usize;
let offset = current_bytes_ptr - base_ptr;

// Here we check the overflow of the buffer offset.
if let Ok(offset) = OffsetType::try_from(offset) {
#[allow(clippy::unnecessary_cast)]
let offset = offset as u32;
payload[12..16].copy_from_slice(&offset.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
} else {
let len = base_buffer.len() - offset;

// Set new buffer
base_buffer = base_buffer.clone().sliced(offset, len);
base_ptr = base_buffer.as_ptr() as usize;

// And add the (truncated) one to the buffers
buffers.push(truncate_buffer(&base_buffer));
buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded");

let offset = 0u32;
payload[12..16].copy_from_slice(&offset.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
}
}

let value = View::from_le_bytes(payload);
unsafe { views.push_unchecked(value) };
}
let buffers = if uses_buffer {
Arc::from([arr.values().clone()])
Arc::from(buffers)
} else {
Arc::from([])
};
Expand All @@ -114,3 +176,33 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray {
unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() }
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn overflowing_utf8_to_binview() {
let values = [
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"123",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"234",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf",
"324",
];
let array = Utf8Array::<i64>::from_slice(values);

let out = utf8_to_utf8view(&array);
// Ensure we hit the multiple buffers part.
assert_eq!(out.buffers().len(), 6);
// Ensure we created a valid binview
let out = out.values_iter().collect::<Vec<_>>();
assert_eq!(out, values);
}
}

0 comments on commit 2b28777

Please sign in to comment.