Skip to content

Commit

Permalink
Switched to storing mz_stream as a raw pointer to fix tree borrows vi…
Browse files Browse the repository at this point in the history
…olation.

Removed Deref and DerefMut implementations for StreamWrapper.
  • Loading branch information
icmccorm committed Jan 8, 2024
1 parent f0463d5 commit cd10bb2
Showing 1 changed file with 85 additions and 74 deletions.
159 changes: 85 additions & 74 deletions src/ffi/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl ErrorMessage {
}

pub struct StreamWrapper {
pub inner: Box<mz_stream>,
pub inner: *mut mz_stream,
}

impl fmt::Debug for StreamWrapper {
Expand All @@ -32,8 +32,11 @@ impl fmt::Debug for StreamWrapper {

impl Default for StreamWrapper {
fn default() -> StreamWrapper {
// We need to store the mz_stream object as a raw pointer since
// a cyclic structure is created in the `state` field, which points
// back to the mz_stream object.
StreamWrapper {
inner: Box::new(mz_stream {
inner: Box::into_raw(Box::new(mz_stream {
next_in: ptr::null_mut(),
avail_in: 0,
total_in: 0,
Expand All @@ -54,7 +57,15 @@ impl Default for StreamWrapper {
zalloc: Some(zalloc),
#[cfg(not(all(feature = "any_zlib", not(feature = "cloudflare-zlib-sys"))))]
zfree: Some(zfree),
}),
})),
}
}
}

impl Drop for StreamWrapper {
fn drop(&mut self) {
unsafe {
drop(Box::from_raw(self.inner));
}
}
}
Expand Down Expand Up @@ -110,20 +121,6 @@ extern "C" fn zfree(_ptr: *mut c_void, address: *mut c_void) {
}
}

impl Deref for StreamWrapper {
type Target = mz_stream;

fn deref(&self) -> &Self::Target {
&*self.inner
}
}

impl DerefMut for StreamWrapper {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.inner
}
}

unsafe impl<D: Direction> Send for Stream<D> {}
unsafe impl<D: Direction> Sync for Stream<D> {}

Expand Down Expand Up @@ -185,9 +182,9 @@ pub struct Inflate {
impl InflateBackend for Inflate {
fn make(zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_inflateInit2(
&mut *state,
state.inner,
if zlib_header {
window_bits as c_int
} else {
Expand All @@ -212,33 +209,40 @@ impl InflateBackend for Inflate {
output: &mut [u8],
flush: FlushDecompress,
) -> Result<Status, DecompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut u8;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_inflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict(raw.adler as u32),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// We need to access the `inner` field of the `StreamWrapper` object
// as a raw pointer here since the field `state` in `mz_stream` is
// a pointer back to the `mz_stream` object. Any mutable borrow against
// inner will become invalidated by `mz_inflate`, leading to an invalid
// dereference after that function returns.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut u8;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_inflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict((*raw).adler as u32),
c => panic!("unknown return code: {}", c),
}
}
}

Expand Down Expand Up @@ -276,9 +280,9 @@ pub struct Deflate {
impl DeflateBackend for Deflate {
fn make(level: Compression, zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_deflateInit2(
&mut *state,
state.inner,
level.0 as c_int,
MZ_DEFLATED,
if zlib_header {
Expand Down Expand Up @@ -306,32 +310,39 @@ impl DeflateBackend for Deflate {
output: &mut [u8],
flush: FlushCompress,
) -> Result<Status, CompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut _;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_deflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// We need to access the `inner` field of the `StreamWrapper` object
// as a raw pointer here since the field `state` in `mz_stream` is
// a pointer back to the `mz_stream` object. Any mutable borrow against
// inner will become invalidated by `mz_deflate`, leading to an invalid
// dereference after that function returns.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut _;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_deflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.

self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;
// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
}
}
}

Expand Down

0 comments on commit cd10bb2

Please sign in to comment.