Skip to content

Commit 3ce8e84

Browse files
authored
Unsafe improvements: core parquet crate. (#6024)
* Unsafe improvements: core `parquet` crate.
* Make `FromBytes` an unsafe trait.
1 parent c47f230 commit 3ce8e84

File tree

3 files changed

+63
-17
lines changed

3 files changed

+63
-17
lines changed

parquet/src/bloom_filter/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ impl Block {
134134

135135
#[inline]
136136
fn to_ne_bytes(self) -> [u8; 32] {
137-
unsafe { std::mem::transmute(self) }
137+
// SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns.
138+
unsafe { std::mem::transmute(self.0) }
138139
}
139140

140141
#[inline]

parquet/src/data_type.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,8 @@ macro_rules! gen_as_bytes {
468468
impl AsBytes for $source_ty {
469469
#[allow(clippy::size_of_in_element_count)]
470470
fn as_bytes(&self) -> &[u8] {
471+
// SAFETY: macro is only used with primitive types that have no padding, so the
472+
// resulting slice always refers to initialized memory.
471473
unsafe {
472474
std::slice::from_raw_parts(
473475
self as *const $source_ty as *const u8,
@@ -481,6 +483,8 @@ macro_rules! gen_as_bytes {
481483
#[inline]
482484
#[allow(clippy::size_of_in_element_count)]
483485
fn slice_as_bytes(self_: &[Self]) -> &[u8] {
486+
// SAFETY: macro is only used with primitive types that have no padding, so the
487+
// resulting slice always refers to initialized memory.
484488
unsafe {
485489
std::slice::from_raw_parts(
486490
self_.as_ptr() as *const u8,
@@ -492,10 +496,15 @@ macro_rules! gen_as_bytes {
492496
#[inline]
493497
#[allow(clippy::size_of_in_element_count)]
494498
unsafe fn slice_as_bytes_mut(self_: &mut [Self]) -> &mut [u8] {
495-
std::slice::from_raw_parts_mut(
496-
self_.as_mut_ptr() as *mut u8,
497-
std::mem::size_of_val(self_),
498-
)
499+
// SAFETY: macro is only used with primitive types that have no padding, so the
500+
// resulting slice always refers to initialized memory. Moreover, self has no
501+
// invalid bit patterns, so all writes to the resulting slice will be valid.
502+
unsafe {
503+
std::slice::from_raw_parts_mut(
504+
self_.as_mut_ptr() as *mut u8,
505+
std::mem::size_of_val(self_),
506+
)
507+
}
499508
}
500509
}
501510
};
@@ -534,12 +543,15 @@ unimplemented_slice_as_bytes!(FixedLenByteArray);
534543

535544
impl AsBytes for bool {
536545
fn as_bytes(&self) -> &[u8] {
546+
// SAFETY: a bool is guaranteed to be either 0x00 or 0x01 in memory, so the memory is
547+
// valid.
537548
unsafe { std::slice::from_raw_parts(self as *const bool as *const u8, 1) }
538549
}
539550
}
540551

541552
impl AsBytes for Int96 {
542553
fn as_bytes(&self) -> &[u8] {
554+
// SAFETY: Int96::data is a &[u32; 3].
543555
unsafe { std::slice::from_raw_parts(self.data() as *const [u32] as *const u8, 12) }
544556
}
545557
}
@@ -718,6 +730,7 @@ pub(crate) mod private {
718730

719731
#[inline]
720732
fn encode<W: std::io::Write>(values: &[Self], writer: &mut W, _: &mut BitWriter) -> Result<()> {
733+
// SAFETY: Self is one of i32, i64, f32, f64, which have no padding.
721734
let raw = unsafe {
722735
std::slice::from_raw_parts(
723736
values.as_ptr() as *const u8,
@@ -747,9 +760,10 @@ pub(crate) mod private {
747760
return Err(eof_err!("Not enough bytes to decode"));
748761
}
749762

750-
// SAFETY: Raw types should be as per the standard rust bit-vectors
751-
unsafe {
752-
let raw_buffer = &mut Self::slice_as_bytes_mut(buffer)[..bytes_to_decode];
763+
{
764+
// SAFETY: Self has no invalid bit patterns, so writing to the slice
765+
// obtained with slice_as_bytes_mut is always safe.
766+
let raw_buffer = &mut unsafe { Self::slice_as_bytes_mut(buffer) }[..bytes_to_decode];
753767
raw_buffer.copy_from_slice(data.slice(
754768
decoder.start..decoder.start + bytes_to_decode
755769
).as_ref());
@@ -810,9 +824,7 @@ pub(crate) mod private {
810824
_: &mut BitWriter,
811825
) -> Result<()> {
812826
for value in values {
813-
let raw = unsafe {
814-
std::slice::from_raw_parts(value.data() as *const [u32] as *const u8, 12)
815-
};
827+
let raw = SliceAsBytes::slice_as_bytes(value.data());
816828
writer.write_all(raw)?;
817829
}
818830
Ok(())

parquet/src/util/bit_util.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
4242
}
4343
}
4444

45-
pub trait FromBytes: Sized {
45+
/// # Safety
46+
/// All bit patterns 00000xxxx, where there are `BIT_CAPACITY` `x`s,
47+
/// must be valid, unless BIT_CAPACITY is 0.
48+
pub unsafe trait FromBytes: Sized {
49+
const BIT_CAPACITY: usize;
4650
type Buffer: AsMut<[u8]> + Default;
4751
fn try_from_le_slice(b: &[u8]) -> Result<Self>;
4852
fn from_le_bytes(bs: Self::Buffer) -> Self;
@@ -51,7 +55,9 @@ pub trait FromBytes: Sized {
5155
macro_rules! from_le_bytes {
5256
($($ty: ty),*) => {
5357
$(
54-
impl FromBytes for $ty {
58+
// SAFETY: this macro is used for types for which all bit patterns are valid.
59+
unsafe impl FromBytes for $ty {
60+
const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
5561
type Buffer = [u8; size_of::<Self>()];
5662
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
5763
Ok(Self::from_le_bytes(array_from_slice(b)?))
@@ -66,7 +72,9 @@ macro_rules! from_le_bytes {
6672

6773
from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 }
6874

69-
impl FromBytes for bool {
75+
// SAFETY: the 0000000x bit pattern is always valid for `bool`.
76+
unsafe impl FromBytes for bool {
77+
const BIT_CAPACITY: usize = 1;
7078
type Buffer = [u8; 1];
7179

7280
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -77,7 +85,9 @@ impl FromBytes for bool {
7785
}
7886
}
7987

80-
impl FromBytes for Int96 {
88+
// SAFETY: BIT_CAPACITY is 0.
89+
unsafe impl FromBytes for Int96 {
90+
const BIT_CAPACITY: usize = 0;
8191
type Buffer = [u8; 12];
8292

8393
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -95,7 +105,9 @@ impl FromBytes for Int96 {
95105
}
96106
}
97107

98-
impl FromBytes for ByteArray {
108+
// SAFETY: BIT_CAPACITY is 0.
109+
unsafe impl FromBytes for ByteArray {
110+
const BIT_CAPACITY: usize = 0;
99111
type Buffer = Vec<u8>;
100112

101113
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -106,7 +118,9 @@ impl FromBytes for ByteArray {
106118
}
107119
}
108120

109-
impl FromBytes for FixedLenByteArray {
121+
// SAFETY: BIT_CAPACITY is 0.
122+
unsafe impl FromBytes for FixedLenByteArray {
123+
const BIT_CAPACITY: usize = 0;
110124
type Buffer = Vec<u8>;
111125

112126
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -457,10 +471,17 @@ impl BitReader {
457471
}
458472
}
459473

474+
assert_ne!(T::BIT_CAPACITY, 0);
475+
assert!(num_bits <= T::BIT_CAPACITY);
476+
460477
// Read directly into output buffer
461478
match size_of::<T>() {
462479
1 => {
463480
let ptr = batch.as_mut_ptr() as *mut u8;
481+
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
482+
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
483+
// unpack{8,16,32,64} only set the lowest num_bits bits to non-zero values, and we
484+
// checked that num_bits <= T::BIT_CAPACITY.
464485
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
465486
while values_to_read - i >= 8 {
466487
let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
@@ -471,6 +492,10 @@ impl BitReader {
471492
}
472493
2 => {
473494
let ptr = batch.as_mut_ptr() as *mut u16;
495+
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
496+
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
497+
// unpack{8,16,32,64} only set the lowest num_bits bits to non-zero values, and we
498+
// checked that num_bits <= T::BIT_CAPACITY.
474499
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
475500
while values_to_read - i >= 16 {
476501
let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
@@ -481,6 +506,10 @@ impl BitReader {
481506
}
482507
4 => {
483508
let ptr = batch.as_mut_ptr() as *mut u32;
509+
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
510+
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
511+
// unpack{8,16,32,64} only set the lowest num_bits bits to non-zero values, and we
512+
// checked that num_bits <= T::BIT_CAPACITY.
484513
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
485514
while values_to_read - i >= 32 {
486515
let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
@@ -491,6 +520,10 @@ impl BitReader {
491520
}
492521
8 => {
493522
let ptr = batch.as_mut_ptr() as *mut u64;
523+
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
524+
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
525+
// unpack{8,16,32,64} only set the lowest num_bits bits to non-zero values, and we
526+
// checked that num_bits <= T::BIT_CAPACITY.
494527
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
495528
while values_to_read - i >= 64 {
496529
let out_slice = (&mut out[i..i + 64]).try_into().unwrap();

0 commit comments

Comments
 (0)